1use crate::{
4 Encoding,
5 Error::{self, InvalidLength},
6 MIN_LINE_WIDTH, encoding,
7 line_ending::{CHAR_CR, CHAR_LF},
8};
9use core::{cmp, marker::PhantomData};
10
11#[cfg(feature = "alloc")]
12use {alloc::vec::Vec, core::iter};
13
14#[cfg(feature = "std")]
15use std::io;
16
17#[cfg(doc)]
18use crate::{Base64, Base64Unpadded};
19
20#[derive(Clone)]
25pub struct Decoder<'i, E: Encoding> {
26 line: Line<'i>,
28
29 line_reader: LineReader<'i>,
31
32 remaining_len: usize,
34
35 block_buffer: BlockBuffer,
37
38 encoding: PhantomData<E>,
40}
41
42impl<'i, E: Encoding> Decoder<'i, E> {
43 pub fn new(input: &'i [u8]) -> Result<Self, Error> {
50 let line_reader = LineReader::new_unwrapped(input)?;
51 let remaining_len = line_reader.decoded_len::<E>()?;
52
53 Ok(Self {
54 line: Line::default(),
55 line_reader,
56 remaining_len,
57 block_buffer: BlockBuffer::default(),
58 encoding: PhantomData,
59 })
60 }
61
62 pub fn new_wrapped(input: &'i [u8], line_width: usize) -> Result<Self, Error> {
87 let line_reader = LineReader::new_wrapped(input, line_width)?;
88 let remaining_len = line_reader.decoded_len::<E>()?;
89
90 Ok(Self {
91 line: Line::default(),
92 line_reader,
93 remaining_len,
94 block_buffer: BlockBuffer::default(),
95 encoding: PhantomData,
96 })
97 }
98
99 pub fn decode<'o>(&mut self, out: &'o mut [u8]) -> Result<&'o [u8], Error> {
107 if self.is_finished() && !out.is_empty() {
108 return Err(InvalidLength);
109 }
110
111 let mut out_pos = 0;
112
113 while out_pos < out.len() {
114 if !self.block_buffer.is_empty() {
116 let out_rem = out.len().checked_sub(out_pos).ok_or(InvalidLength)?;
117 let bytes = self.block_buffer.take(out_rem)?;
118 out[out_pos..][..bytes.len()].copy_from_slice(bytes);
119 out_pos = out_pos.checked_add(bytes.len()).ok_or(InvalidLength)?;
120 }
121
122 if self.line.is_empty() && !self.line_reader.is_empty() {
124 self.advance_line()?;
125 }
126
127 let in_blocks = self.line.len() / 4;
129 let out_rem = out.len().checked_sub(out_pos).ok_or(InvalidLength)?;
130 let out_blocks = out_rem / 3;
131 let blocks = cmp::min(in_blocks, out_blocks);
132 let in_aligned = self.line.take(blocks.checked_mul(4).ok_or(InvalidLength)?);
133
134 if !in_aligned.is_empty() {
135 let out_buf = &mut out[out_pos..][..blocks.checked_mul(3).ok_or(InvalidLength)?];
136 let decoded_len = self.perform_decode(in_aligned, out_buf)?.len();
137 out_pos = out_pos.checked_add(decoded_len).ok_or(InvalidLength)?;
138 }
139
140 if out_pos < out.len() {
141 if self.is_finished() {
142 return Err(InvalidLength);
145 } else {
146 self.fill_block_buffer()?;
151 }
152 }
153 }
154
155 self.remaining_len = self
156 .remaining_len
157 .checked_sub(out.len())
158 .ok_or(InvalidLength)?;
159
160 Ok(out)
161 }
162
163 #[cfg(feature = "alloc")]
168 pub fn decode_to_end<'o>(&mut self, buf: &'o mut Vec<u8>) -> Result<&'o [u8], Error> {
169 let start_len = buf.len();
170 let remaining_len = self.remaining_len();
171 let total_len = start_len.checked_add(remaining_len).ok_or(InvalidLength)?;
172
173 if total_len > buf.capacity() {
174 buf.reserve(total_len.checked_sub(buf.capacity()).ok_or(InvalidLength)?);
175 }
176
177 buf.extend(iter::repeat_n(0, remaining_len));
179 self.decode(&mut buf[start_len..])?;
180 Ok(&buf[start_len..])
181 }
182
183 pub fn remaining_len(&self) -> usize {
187 self.remaining_len
188 }
189
190 pub fn is_finished(&self) -> bool {
192 self.line.is_empty() && self.line_reader.is_empty() && self.block_buffer.is_empty()
193 }
194
195 fn fill_block_buffer(&mut self) -> Result<(), Error> {
197 let mut buf = [0u8; BlockBuffer::SIZE];
198
199 let decoded = if self.line.len() < 4 && !self.line_reader.is_empty() {
200 let mut tmp = [0u8; 4];
202
203 let line_end = self.line.take(4);
205 tmp[..line_end.len()].copy_from_slice(line_end);
206
207 self.advance_line()?;
209 let len = 4usize.checked_sub(line_end.len()).ok_or(InvalidLength)?;
210 let line_begin = self.line.take(len);
211 tmp[line_end.len()..][..line_begin.len()].copy_from_slice(line_begin);
212
213 let tmp_len = line_begin
214 .len()
215 .checked_add(line_end.len())
216 .ok_or(InvalidLength)?;
217
218 self.perform_decode(&tmp[..tmp_len], &mut buf)
219 } else {
220 let block = self.line.take(4);
221 self.perform_decode(block, &mut buf)
222 }?;
223
224 self.block_buffer.fill(decoded)
225 }
226
227 fn advance_line(&mut self) -> Result<(), Error> {
229 debug_assert!(self.line.is_empty(), "expected line buffer to be empty");
230
231 if let Some(line) = self.line_reader.next().transpose()? {
232 self.line = line;
233 Ok(())
234 } else {
235 Err(InvalidLength)
236 }
237 }
238
239 fn perform_decode<'o>(&self, src: &[u8], dst: &'o mut [u8]) -> Result<&'o [u8], Error> {
241 if self.is_finished() {
242 E::decode(src, dst)
243 } else {
244 E::Unpadded::decode(src, dst)
245 }
246 }
247}
248
#[cfg(feature = "std")]
impl<E: Encoding> io::Read for Decoder<'_, E> {
    /// Decode up to `buf.len()` bytes into `buf`, returning how many were
    /// written. Returns `Ok(0)` once all input has been decoded.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.is_finished() {
            return Ok(0);
        }

        // Never ask `decode` for more than remains; it rejects over-long output.
        let slice = match buf.get_mut(..self.remaining_len()) {
            Some(bytes) => bytes,
            None => buf,
        };

        self.decode(slice)?;
        Ok(slice.len())
    }

    /// Decode everything that remains, appending it to `buf`.
    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
        if self.is_finished() {
            return Ok(0);
        }
        Ok(self.decode_to_end(buf)?.len())
    }

    /// Fill `buf` completely with decoded data.
    ///
    /// If `buf` is longer than the data remaining, this fails with
    /// [`io::ErrorKind::UnexpectedEof`] (as the [`io::Read::read_exact`]
    /// contract requires) before consuming any input. Other decoding errors
    /// are converted into [`io::Error`].
    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        if buf.len() > self.remaining_len() {
            return Err(io::ErrorKind::UnexpectedEof.into());
        }
        self.decode(buf)?;
        Ok(())
    }
}
276
277#[derive(Clone, Default, Debug)]
282struct BlockBuffer {
283 decoded: [u8; Self::SIZE],
285
286 length: usize,
288
289 position: usize,
291}
292
293impl BlockBuffer {
294 const SIZE: usize = 3;
296
297 fn fill(&mut self, decoded_input: &[u8]) -> Result<(), Error> {
299 debug_assert!(self.is_empty());
300
301 if decoded_input.len() > Self::SIZE {
302 return Err(InvalidLength);
303 }
304
305 self.position = 0;
306 self.length = decoded_input.len();
307 self.decoded[..decoded_input.len()].copy_from_slice(decoded_input);
308 Ok(())
309 }
310
311 fn take(&mut self, mut nbytes: usize) -> Result<&[u8], Error> {
316 debug_assert!(self.position <= self.length);
317 let start_pos = self.position;
318 let remaining_len = self.length.checked_sub(start_pos).ok_or(InvalidLength)?;
319
320 if nbytes > remaining_len {
321 nbytes = remaining_len;
322 }
323
324 self.position = self.position.checked_add(nbytes).ok_or(InvalidLength)?;
325 Ok(&self.decoded[start_pos..][..nbytes])
326 }
327
328 fn is_empty(&self) -> bool {
330 self.position == self.length
331 }
332}
333
334#[derive(Clone, Debug)]
336pub struct Line<'i> {
337 remaining: &'i [u8],
339}
340
341impl Default for Line<'_> {
342 fn default() -> Self {
343 Self::new(&[])
344 }
345}
346
347impl<'i> Line<'i> {
348 fn new(bytes: &'i [u8]) -> Self {
350 Self { remaining: bytes }
351 }
352
353 fn take(&mut self, nbytes: usize) -> &'i [u8] {
355 let (bytes, rest) = if nbytes < self.remaining.len() {
356 self.remaining.split_at(nbytes)
357 } else {
358 (self.remaining, [].as_ref())
359 };
360
361 self.remaining = rest;
362 bytes
363 }
364
365 fn slice_tail(&self, nbytes: usize) -> Result<&'i [u8], Error> {
367 let offset = self.len().checked_sub(nbytes).ok_or(InvalidLength)?;
368 self.remaining.get(offset..).ok_or(InvalidLength)
369 }
370
371 fn len(&self) -> usize {
373 self.remaining.len()
374 }
375
376 fn is_empty(&self) -> bool {
378 self.len() == 0
379 }
380
381 fn trim_end(&self) -> Self {
383 Line::new(match self.remaining {
384 [line @ .., CHAR_CR, CHAR_LF] => line,
385 [line @ .., CHAR_CR] => line,
386 [line @ .., CHAR_LF] => line,
387 line => line,
388 })
389 }
390}
391
392#[derive(Clone)]
394struct LineReader<'i> {
395 remaining: &'i [u8],
397
398 line_width: Option<usize>,
400}
401
402impl<'i> LineReader<'i> {
403 fn new_unwrapped(bytes: &'i [u8]) -> Result<Self, Error> {
405 if bytes.is_empty() {
406 Err(InvalidLength)
407 } else {
408 Ok(Self {
409 remaining: bytes,
410 line_width: None,
411 })
412 }
413 }
414
415 fn new_wrapped(bytes: &'i [u8], line_width: usize) -> Result<Self, Error> {
417 if line_width < MIN_LINE_WIDTH {
418 return Err(InvalidLength);
419 }
420
421 let mut reader = Self::new_unwrapped(bytes)?;
422 reader.line_width = Some(line_width);
423 Ok(reader)
424 }
425
426 fn is_empty(&self) -> bool {
428 self.remaining.is_empty()
429 }
430
431 fn decoded_len<E: Encoding>(&self) -> Result<usize, Error> {
433 let mut buffer = [0u8; 4];
434 let mut lines = self.clone();
435 let mut line = match lines.next().transpose()? {
436 Some(l) => l,
437 None => return Ok(0),
438 };
439 let mut base64_len = 0usize;
440
441 loop {
442 base64_len = base64_len.checked_add(line.len()).ok_or(InvalidLength)?;
443
444 match lines.next().transpose()? {
445 Some(l) => {
446 buffer.copy_from_slice(line.slice_tail(4)?);
449
450 line = l
451 }
452
453 None => {
458 let base64_last_block_len = match base64_len % 4 {
460 0 => 4,
461 n => n,
462 };
463
464 let decoded_len = encoding::decoded_len(
466 base64_len
467 .checked_sub(base64_last_block_len)
468 .ok_or(InvalidLength)?,
469 );
470
471 let mut out = [0u8; 3];
473 let last_block_len = if line.len() < base64_last_block_len {
474 let buffered_part_len = base64_last_block_len
475 .checked_sub(line.len())
476 .ok_or(InvalidLength)?;
477
478 let offset = 4usize.checked_sub(buffered_part_len).ok_or(InvalidLength)?;
479
480 for i in 0..buffered_part_len {
481 buffer[i] = buffer[offset.checked_add(i).ok_or(InvalidLength)?];
482 }
483
484 buffer[buffered_part_len..][..line.len()].copy_from_slice(line.remaining);
485 let buffer_len = buffered_part_len
486 .checked_add(line.len())
487 .ok_or(InvalidLength)?;
488
489 E::decode(&buffer[..buffer_len], &mut out)?.len()
490 } else {
491 let last_block = line.slice_tail(base64_last_block_len)?;
492 E::decode(last_block, &mut out)?.len()
493 };
494
495 return decoded_len.checked_add(last_block_len).ok_or(InvalidLength);
496 }
497 }
498 }
499 }
500}
501
502impl<'i> Iterator for LineReader<'i> {
503 type Item = Result<Line<'i>, Error>;
504
505 fn next(&mut self) -> Option<Result<Line<'i>, Error>> {
506 if let Some(line_width) = self.line_width {
507 let rest = match self.remaining.get(line_width..) {
508 None | Some([]) => {
509 if self.remaining.is_empty() {
510 return None;
511 } else {
512 let line = Line::new(self.remaining).trim_end();
513 self.remaining = &[];
514 return Some(Ok(line));
515 }
516 }
517 Some([CHAR_CR, CHAR_LF, rest @ ..]) => rest,
518 Some([CHAR_CR, rest @ ..]) => rest,
519 Some([CHAR_LF, rest @ ..]) => rest,
520 _ => {
521 return Some(Err(Error::InvalidEncoding));
523 }
524 };
525
526 let line = Line::new(&self.remaining[..line_width]);
527 self.remaining = rest;
528 Some(Ok(line))
529 } else if !self.remaining.is_empty() {
530 let line = Line::new(self.remaining).trim_end();
531 self.remaining = b"";
532
533 if line.is_empty() {
534 None
535 } else {
536 Some(Ok(line))
537 }
538 } else {
539 None
540 }
541 }
542}
543
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use crate::{Base64, Base64Unpadded, Decoder, Error, alphabet::Alphabet, test_vectors::*};

    #[cfg(feature = "std")]
    use {alloc::vec::Vec, std::io::Read};

    #[test]
    fn decode_padded() {
        decode_test(PADDED_BIN, || {
            Decoder::<Base64>::new(PADDED_BASE64.as_bytes()).unwrap()
        })
    }

    #[test]
    fn decode_unpadded() {
        decode_test(UNPADDED_BIN, || {
            Decoder::<Base64Unpadded>::new(UNPADDED_BASE64.as_bytes()).unwrap()
        })
    }

    #[test]
    fn decode_multiline_padded() {
        decode_test(MULTILINE_PADDED_BIN, || {
            Decoder::<Base64>::new_wrapped(MULTILINE_PADDED_BASE64.as_bytes(), 70).unwrap()
        })
    }

    #[test]
    fn decode_multiline_unpadded() {
        decode_test(MULTILINE_UNPADDED_BIN, || {
            Decoder::<Base64Unpadded>::new_wrapped(MULTILINE_UNPADDED_BASE64.as_bytes(), 70)
                .unwrap()
        })
    }

    #[cfg(feature = "std")]
    #[test]
    fn read_multiline_padded() {
        let mut decoder =
            Decoder::<Base64>::new_wrapped(MULTILINE_PADDED_BASE64.as_bytes(), 70).unwrap();

        let mut buf = Vec::new();
        let len = decoder.read_to_end(&mut buf).unwrap();

        assert_eq!(len, MULTILINE_PADDED_BIN.len());
        assert_eq!(buf.as_slice(), MULTILINE_PADDED_BIN);
    }

    #[cfg(feature = "std")]
    #[test]
    fn decode_empty_at_end() {
        let mut decoder = Decoder::<Base64>::new(b"AAAA").unwrap();

        // Decode all of the data.
        let mut buf = vec![0u8; 3];
        assert_eq!(decoder.decode(&mut buf), Ok(&vec![0u8; 3][..]));

        // Decoding into an empty slice at the end is a no-op, not an error.
        let mut buf: Vec<u8> = vec![];
        assert_eq!(decoder.decode(&mut buf), Ok(&[][..]));
    }

    /// Decode `expected` with every chunk size from one byte up to the whole
    /// output in a single call, checking the data and `remaining_len` after
    /// each chunk, and that decoding past the end is rejected.
    #[allow(clippy::arithmetic_side_effects)]
    fn decode_test<'a, F, V>(expected: &[u8], f: F)
    where
        F: Fn() -> Decoder<'a, V>,
        V: Alphabet,
    {
        // Inclusive range so decoding everything in one call is covered too.
        for chunk_size in 1..=expected.len() {
            let mut decoder = f();
            let mut remaining_len = decoder.remaining_len();
            let mut buffer = [0u8; 1024];

            for chunk in expected.chunks(chunk_size) {
                assert!(!decoder.is_finished());
                let decoded = decoder.decode(&mut buffer[..chunk.len()]).unwrap();
                assert_eq!(chunk, decoded);

                remaining_len -= decoded.len();
                assert_eq!(remaining_len, decoder.remaining_len());
            }

            assert!(decoder.is_finished());
            assert_eq!(decoder.remaining_len(), 0);

            // Asking for more data once finished must fail.
            assert_eq!(decoder.decode(&mut buffer[..1]), Err(Error::InvalidLength));
        }
    }
}