tokio/io/util/
chain.rs

1use crate::io::{AsyncBufRead, AsyncRead, ReadBuf};
2
3use pin_project_lite::pin_project;
4use std::fmt;
5use std::io;
6use std::pin::Pin;
7use std::task::{ready, Context, Poll};
8
9pin_project! {
10    /// Stream for the [`chain`](super::AsyncReadExt::chain) method.
11    #[must_use = "streams do nothing unless polled"]
12    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
13    pub struct Chain<T, U> {
14        #[pin]
15        first: T,
16        #[pin]
17        second: U,
18        done_first: bool,
19    }
20}
21
22pub(super) fn chain<T, U>(first: T, second: U) -> Chain<T, U>
23where
24    T: AsyncRead,
25    U: AsyncRead,
26{
27    Chain {
28        first,
29        second,
30        done_first: false,
31    }
32}
33
34impl<T, U> Chain<T, U>
35where
36    T: AsyncRead,
37    U: AsyncRead,
38{
39    /// Gets references to the underlying readers in this `Chain`.
40    pub fn get_ref(&self) -> (&T, &U) {
41        (&self.first, &self.second)
42    }
43
44    /// Gets mutable references to the underlying readers in this `Chain`.
45    ///
46    /// Care should be taken to avoid modifying the internal I/O state of the
47    /// underlying readers as doing so may corrupt the internal state of this
48    /// `Chain`.
49    pub fn get_mut(&mut self) -> (&mut T, &mut U) {
50        (&mut self.first, &mut self.second)
51    }
52
53    /// Gets pinned mutable references to the underlying readers in this `Chain`.
54    ///
55    /// Care should be taken to avoid modifying the internal I/O state of the
56    /// underlying readers as doing so may corrupt the internal state of this
57    /// `Chain`.
58    pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
59        let me = self.project();
60        (me.first, me.second)
61    }
62
63    /// Consumes the `Chain`, returning the wrapped readers.
64    pub fn into_inner(self) -> (T, U) {
65        (self.first, self.second)
66    }
67}
68
69impl<T, U> fmt::Debug for Chain<T, U>
70where
71    T: fmt::Debug,
72    U: fmt::Debug,
73{
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        f.debug_struct("Chain")
76            .field("t", &self.first)
77            .field("u", &self.second)
78            .finish()
79    }
80}
81
82impl<T, U> AsyncRead for Chain<T, U>
83where
84    T: AsyncRead,
85    U: AsyncRead,
86{
87    fn poll_read(
88        self: Pin<&mut Self>,
89        cx: &mut Context<'_>,
90        buf: &mut ReadBuf<'_>,
91    ) -> Poll<io::Result<()>> {
92        let me = self.project();
93
94        if !*me.done_first {
95            let rem = buf.remaining();
96            ready!(me.first.poll_read(cx, buf))?;
97            if buf.remaining() == rem {
98                *me.done_first = true;
99            } else {
100                return Poll::Ready(Ok(()));
101            }
102        }
103        me.second.poll_read(cx, buf)
104    }
105}
106
107impl<T, U> AsyncBufRead for Chain<T, U>
108where
109    T: AsyncBufRead,
110    U: AsyncBufRead,
111{
112    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
113        let me = self.project();
114
115        if !*me.done_first {
116            match ready!(me.first.poll_fill_buf(cx)?) {
117                [] => {
118                    *me.done_first = true;
119                }
120                buf => return Poll::Ready(Ok(buf)),
121            }
122        }
123        me.second.poll_fill_buf(cx)
124    }
125
126    fn consume(self: Pin<&mut Self>, amt: usize) {
127        let me = self.project();
128        if !*me.done_first {
129            me.first.consume(amt)
130        } else {
131            me.second.consume(amt)
132        }
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    // Instantiates `crate::is_unpin` with a `Chain` of two unit types;
    // presumably a compile-time check that such a `Chain` is `Unpin`.
    #[test]
    fn assert_unpin() {
        type UnitChain = Chain<(), ()>;
        crate::is_unpin::<UnitChain>();
    }
}
144}