async_compression/futures/bufread/generic/encoder.rs

1use core::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5use std::io::Result;
6
7use crate::codecs::Encode;
8use crate::core::util::PartialBuffer;
9use futures_core::ready;
10use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite, IoSlice};
11use pin_project_lite::pin_project;
12
/// Phase of the read-side state machine driven by `Encoder::do_poll_read`.
#[derive(Debug)]
enum State {
    /// Pulling chunks from the inner reader and feeding them to the codec.
    Encoding,
    /// The inner reader returned an empty fill (EOF); calling the codec's
    /// `finish` until it reports completion.
    Flushing,
    /// The codec's `finish` returned `true`; every later read yields 0 bytes.
    Done,
}
19
20pin_project! {
21    #[derive(Debug)]
22    pub struct Encoder<R, E> {
23        #[pin]
24        reader: R,
25        encoder: E,
26        state: State,
27    }
28}
29
30impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
31    pub fn new(reader: R, encoder: E) -> Self {
32        Self {
33            reader,
34            encoder,
35            state: State::Encoding,
36        }
37    }
38
39    pub fn with_capacity(reader: R, encoder: E, _cap: usize) -> Self {
40        Self::new(reader, encoder)
41    }
42}
43
44impl<R, E> Encoder<R, E> {
45    pub fn get_ref(&self) -> &R {
46        &self.reader
47    }
48
49    pub fn get_mut(&mut self) -> &mut R {
50        &mut self.reader
51    }
52
53    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
54        self.project().reader
55    }
56
57    pub(crate) fn get_encoder_ref(&self) -> &E {
58        &self.encoder
59    }
60
61    pub fn into_inner(self) -> R {
62        self.reader
63    }
64}
65
66impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
67    fn do_poll_read(
68        self: Pin<&mut Self>,
69        cx: &mut Context<'_>,
70        output: &mut PartialBuffer<&mut [u8]>,
71    ) -> Poll<Result<()>> {
72        let mut this = self.project();
73
74        loop {
75            *this.state = match this.state {
76                State::Encoding => {
77                    let input = ready!(this.reader.as_mut().poll_fill_buf(cx))?;
78                    if input.is_empty() {
79                        State::Flushing
80                    } else {
81                        let mut input = PartialBuffer::new(input);
82                        this.encoder.encode(&mut input, output)?;
83                        let len = input.written().len();
84                        this.reader.as_mut().consume(len);
85                        State::Encoding
86                    }
87                }
88
89                State::Flushing => {
90                    if this.encoder.finish(output)? {
91                        State::Done
92                    } else {
93                        State::Flushing
94                    }
95                }
96
97                State::Done => State::Done,
98            };
99
100            if let State::Done = *this.state {
101                return Poll::Ready(Ok(()));
102            }
103            if output.unwritten().is_empty() {
104                return Poll::Ready(Ok(()));
105            }
106        }
107    }
108}
109
110impl<R: AsyncBufRead, E: Encode> AsyncRead for Encoder<R, E> {
111    fn poll_read(
112        self: Pin<&mut Self>,
113        cx: &mut Context<'_>,
114        buf: &mut [u8],
115    ) -> Poll<Result<usize>> {
116        if buf.is_empty() {
117            return Poll::Ready(Ok(0));
118        }
119
120        let mut output = PartialBuffer::new(buf);
121        match self.do_poll_read(cx, &mut output)? {
122            Poll::Pending if output.written().is_empty() => Poll::Pending,
123            _ => Poll::Ready(Ok(output.written().len())),
124        }
125    }
126}
127
128impl<R: AsyncWrite, E> AsyncWrite for Encoder<R, E> {
129    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
130        self.get_pin_mut().poll_write(cx, buf)
131    }
132
133    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
134        self.get_pin_mut().poll_flush(cx)
135    }
136
137    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
138        self.get_pin_mut().poll_close(cx)
139    }
140
141    fn poll_write_vectored(
142        self: Pin<&mut Self>,
143        cx: &mut Context<'_>,
144        bufs: &[IoSlice<'_>],
145    ) -> Poll<Result<usize>> {
146        self.get_pin_mut().poll_write_vectored(cx, bufs)
147    }
148}