async_compression/futures/write/generic/encoder.rs

1use std::{
2    io,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use crate::codecs::Encode;
8use crate::core::util::PartialBuffer;
9use crate::futures::write::{AsyncBufWrite, BufWriter};
10use futures_core::ready;
11use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite, IoSliceMut};
12use pin_project_lite::pin_project;
13
/// Lifecycle of the encoder: only moves forward, never back to `Encoding`.
#[derive(Debug)]
enum State {
    /// Initial state: input may be written and the codec may be flushed.
    Encoding,
    /// Closing has begun but the codec has not yet reported that all
    /// trailing output was emitted; writes and flushes are rejected.
    Finishing,
    /// The codec reported that finishing is complete; writes and flushes
    /// are rejected.
    Done,
}
20
pin_project! {
    /// Asynchronous writer adapter that runs every written byte through the
    /// codec `E` and stages the encoded output in a `BufWriter` around `W`.
    ///
    /// Created in the `Encoding` state; `poll_close` drives it through
    /// `Finishing` to `Done`, after which writes and flushes fail.
    #[derive(Debug)]
    pub struct Encoder<W, E> {
        // Buffered wrapper around the underlying writer. Structurally pinned
        // (`#[pin]`) so `project()` yields `Pin<&mut BufWriter<W>>`.
        #[pin]
        writer: BufWriter<W>,
        // Codec that turns written input into encoded output.
        encoder: E,
        // Where the stream is in its Encoding -> Finishing -> Done lifecycle.
        state: State,
    }
}
30
31impl<W, E> Encoder<W, E> {
32    pub fn get_ref(&self) -> &W {
33        self.writer.get_ref()
34    }
35
36    pub fn get_mut(&mut self) -> &mut W {
37        self.writer.get_mut()
38    }
39
40    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
41        self.project().writer.get_pin_mut()
42    }
43
44    pub(crate) fn get_encoder_ref(&self) -> &E {
45        &self.encoder
46    }
47
48    pub fn into_inner(self) -> W {
49        self.writer.into_inner()
50    }
51}
52
53impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
54    pub fn new(writer: W, encoder: E) -> Self {
55        Self {
56            writer: BufWriter::new(writer),
57            encoder,
58            state: State::Encoding,
59        }
60    }
61
62    pub fn with_capacity(writer: W, encoder: E, cap: usize) -> Self {
63        Self {
64            writer: BufWriter::with_capacity(cap, writer),
65            encoder,
66            state: State::Encoding,
67        }
68    }
69
70    fn do_poll_write(
71        self: Pin<&mut Self>,
72        cx: &mut Context<'_>,
73        input: &mut PartialBuffer<&[u8]>,
74    ) -> Poll<io::Result<()>> {
75        let mut this = self.project();
76
77        loop {
78            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
79            let mut output = PartialBuffer::new(output);
80
81            *this.state = match this.state {
82                State::Encoding => {
83                    this.encoder.encode(input, &mut output)?;
84                    State::Encoding
85                }
86
87                State::Finishing | State::Done => {
88                    return Poll::Ready(Err(io::Error::other("Write after close")))
89                }
90            };
91
92            let produced = output.written().len();
93            this.writer.as_mut().produce(produced);
94
95            if input.unwritten().is_empty() {
96                return Poll::Ready(Ok(()));
97            }
98        }
99    }
100
101    fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
102        let mut this = self.project();
103
104        loop {
105            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
106            let mut output = PartialBuffer::new(output);
107
108            let done = match this.state {
109                State::Encoding => this.encoder.flush(&mut output)?,
110
111                State::Finishing | State::Done => {
112                    return Poll::Ready(Err(io::Error::other("Flush after close")))
113                }
114            };
115
116            let produced = output.written().len();
117            this.writer.as_mut().produce(produced);
118
119            if done {
120                return Poll::Ready(Ok(()));
121            }
122        }
123    }
124
125    fn do_poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
126        let mut this = self.project();
127
128        loop {
129            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
130            let mut output = PartialBuffer::new(output);
131
132            *this.state = match this.state {
133                State::Encoding | State::Finishing => {
134                    if this.encoder.finish(&mut output)? {
135                        State::Done
136                    } else {
137                        State::Finishing
138                    }
139                }
140
141                State::Done => State::Done,
142            };
143
144            let produced = output.written().len();
145            this.writer.as_mut().produce(produced);
146
147            if let State::Done = this.state {
148                return Poll::Ready(Ok(()));
149            }
150        }
151    }
152}
153
154impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
155    fn poll_write(
156        self: Pin<&mut Self>,
157        cx: &mut Context<'_>,
158        buf: &[u8],
159    ) -> Poll<io::Result<usize>> {
160        if buf.is_empty() {
161            return Poll::Ready(Ok(0));
162        }
163
164        let mut input = PartialBuffer::new(buf);
165
166        match self.do_poll_write(cx, &mut input)? {
167            Poll::Pending if input.written().is_empty() => Poll::Pending,
168            _ => Poll::Ready(Ok(input.written().len())),
169        }
170    }
171
172    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173        ready!(self.as_mut().do_poll_flush(cx))?;
174        ready!(self.project().writer.as_mut().poll_flush(cx))?;
175        Poll::Ready(Ok(()))
176    }
177
178    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
179        ready!(self.as_mut().do_poll_close(cx))?;
180        ready!(self.project().writer.as_mut().poll_close(cx))?;
181        Poll::Ready(Ok(()))
182    }
183}
184
185impl<W: AsyncRead, E> AsyncRead for Encoder<W, E> {
186    fn poll_read(
187        self: Pin<&mut Self>,
188        cx: &mut Context<'_>,
189        buf: &mut [u8],
190    ) -> Poll<io::Result<usize>> {
191        self.get_pin_mut().poll_read(cx, buf)
192    }
193
194    fn poll_read_vectored(
195        self: Pin<&mut Self>,
196        cx: &mut Context<'_>,
197        bufs: &mut [IoSliceMut<'_>],
198    ) -> Poll<io::Result<usize>> {
199        self.get_pin_mut().poll_read_vectored(cx, bufs)
200    }
201}
202
203impl<W: AsyncBufRead, E> AsyncBufRead for Encoder<W, E> {
204    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
205        self.get_pin_mut().poll_fill_buf(cx)
206    }
207
208    fn consume(self: Pin<&mut Self>, amt: usize) {
209        self.get_pin_mut().consume(amt)
210    }
211}