1use super::{BorrowedBuf, BufReader, BufWriter, DEFAULT_BUF_SIZE, Read, Result, Write};
2use crate::alloc::Allocator;
3use crate::cmp;
4use crate::collections::VecDeque;
5use crate::io::IoSlice;
6use crate::mem::MaybeUninit;
7
8#[cfg(test)]
9mod tests;
10
11#[stable(feature = "rust1", since = "1.0.0")]
61pub fn copy<R: ?Sized, W: ?Sized>(reader: &mut R, writer: &mut W) -> Result<u64>
62where
63    R: Read,
64    W: Write,
65{
66    cfg_if::cfg_if! {
67        if #[cfg(any(target_os = "linux", target_os = "android"))] {
68            crate::sys::kernel_copy::copy_spec(reader, writer)
69        } else {
70            generic_copy(reader, writer)
71        }
72    }
73}
74
75pub(crate) fn generic_copy<R: ?Sized, W: ?Sized>(reader: &mut R, writer: &mut W) -> Result<u64>
78where
79    R: Read,
80    W: Write,
81{
82    let read_buf = BufferedReaderSpec::buffer_size(reader);
83    let write_buf = BufferedWriterSpec::buffer_size(writer);
84
85    if read_buf >= DEFAULT_BUF_SIZE && read_buf >= write_buf {
86        return BufferedReaderSpec::copy_to(reader, writer);
87    }
88
89    BufferedWriterSpec::copy_from(writer, reader)
90}
91
/// Reader-side hook for `generic_copy`.
///
/// Lets a reader advertise a buffer it already owns so the copy can be driven
/// from that buffer instead of from a buffer on the writer side.
trait BufferedReaderSpec {
    /// Size of the buffer this reader can offer. `generic_copy` compares it
    /// with `DEFAULT_BUF_SIZE` and with the writer's buffer size.
    fn buffer_size(&self) -> usize;

    /// Copies this reader's data into `to`, returning a `u64` count of the
    /// bytes handed to the writer.
    fn copy_to(&mut self, to: &mut (impl Write + ?Sized)) -> Result<u64>;
}
100
101impl<T> BufferedReaderSpec for T
102where
103    Self: Read,
104    T: ?Sized,
105{
106    #[inline]
107    default fn buffer_size(&self) -> usize {
108        0
109    }
110
111    default fn copy_to(&mut self, _to: &mut (impl Write + ?Sized)) -> Result<u64> {
112        unreachable!("only called from specializations")
113    }
114}
115
116impl BufferedReaderSpec for &[u8] {
117    fn buffer_size(&self) -> usize {
118        usize::MAX
121    }
122
123    fn copy_to(&mut self, to: &mut (impl Write + ?Sized)) -> Result<u64> {
124        let len = self.len();
125        to.write_all(self)?;
126        *self = &self[len..];
127        Ok(len as u64)
128    }
129}
130
131impl<A: Allocator> BufferedReaderSpec for VecDeque<u8, A> {
132    fn buffer_size(&self) -> usize {
133        usize::MAX
136    }
137
138    fn copy_to(&mut self, to: &mut (impl Write + ?Sized)) -> Result<u64> {
139        let len = self.len();
140        let (front, back) = self.as_slices();
141        let bufs = &mut [IoSlice::new(front), IoSlice::new(back)];
142        to.write_all_vectored(bufs)?;
143        self.clear();
144        Ok(len as u64)
145    }
146}
147
148impl<I> BufferedReaderSpec for BufReader<I>
149where
150    Self: Read,
151    I: ?Sized,
152{
153    fn buffer_size(&self) -> usize {
154        self.capacity()
155    }
156
157    fn copy_to(&mut self, to: &mut (impl Write + ?Sized)) -> Result<u64> {
158        let mut len = 0;
159
160        loop {
161            match self.read(&mut []) {
166                Ok(_) => {}
167                Err(e) if e.is_interrupted() => continue,
168                Err(e) => return Err(e),
169            }
170            let buf = self.buffer();
171            if self.buffer().len() == 0 {
172                return Ok(len);
173            }
174
175            to.write_all(buf)?;
181            len += buf.len() as u64;
182            self.discard_buffer();
183        }
184    }
185}
186
/// Writer-side hook for `generic_copy`.
///
/// Lets a writer advertise a buffer it already owns and pull data from a
/// reader itself.
trait BufferedWriterSpec: Write {
    /// Size of the buffer this writer can offer. `generic_copy` compares it
    /// with the reader's buffer size.
    fn buffer_size(&self) -> usize;

    /// Pulls data from `reader` into this writer, returning a `u64` count of
    /// the bytes taken from the reader.
    fn copy_from<R>(&mut self, reader: &mut R) -> Result<u64>
    where
        R: Read + ?Sized;
}
194
195impl<W: Write + ?Sized> BufferedWriterSpec for W {
196    #[inline]
197    default fn buffer_size(&self) -> usize {
198        0
199    }
200
201    default fn copy_from<R: Read + ?Sized>(&mut self, reader: &mut R) -> Result<u64> {
202        stack_buffer_copy(reader, self)
203    }
204}
205
206impl<I: Write + ?Sized> BufferedWriterSpec for BufWriter<I> {
207    fn buffer_size(&self) -> usize {
208        self.capacity()
209    }
210
211    fn copy_from<R: Read + ?Sized>(&mut self, reader: &mut R) -> Result<u64> {
212        if self.capacity() < DEFAULT_BUF_SIZE {
213            return stack_buffer_copy(reader, self);
214        }
215
216        let mut len = 0;
217        let mut init = 0;
218
219        loop {
220            let buf = self.buffer_mut();
221            let mut read_buf: BorrowedBuf<'_> = buf.spare_capacity_mut().into();
222
223            unsafe {
224                read_buf.set_init(init);
226            }
227
228            if read_buf.capacity() >= DEFAULT_BUF_SIZE {
229                let mut cursor = read_buf.unfilled();
230                match reader.read_buf(cursor.reborrow()) {
231                    Ok(()) => {
232                        let bytes_read = cursor.written();
233
234                        if bytes_read == 0 {
235                            return Ok(len);
236                        }
237
238                        init = read_buf.init_len() - bytes_read;
239                        len += bytes_read as u64;
240
241                        unsafe { buf.set_len(buf.len() + bytes_read) };
243
244                        }
247                    Err(ref e) if e.is_interrupted() => {}
248                    Err(e) => return Err(e),
249                }
250            } else {
251                self.flush_buf()?;
252                init = 0;
253            }
254        }
255    }
256}
257
258impl BufferedWriterSpec for Vec<u8> {
259    fn buffer_size(&self) -> usize {
260        cmp::max(DEFAULT_BUF_SIZE, self.capacity() - self.len())
261    }
262
263    fn copy_from<R: Read + ?Sized>(&mut self, reader: &mut R) -> Result<u64> {
264        reader.read_to_end(self).map(|bytes| u64::try_from(bytes).expect("usize overflowed u64"))
265    }
266}
267
268pub fn stack_buffer_copy<R: Read + ?Sized, W: Write + ?Sized>(
269    reader: &mut R,
270    writer: &mut W,
271) -> Result<u64> {
272    let buf: &mut [_] = &mut [MaybeUninit::uninit(); DEFAULT_BUF_SIZE];
273    let mut buf: BorrowedBuf<'_> = buf.into();
274
275    let mut len = 0;
276
277    loop {
278        match reader.read_buf(buf.unfilled()) {
279            Ok(()) => {}
280            Err(e) if e.is_interrupted() => continue,
281            Err(e) => return Err(e),
282        };
283
284        if buf.filled().is_empty() {
285            break;
286        }
287
288        len += buf.filled().len() as u64;
289        writer.write_all(buf.filled())?;
290        buf.clear();
291    }
292
293    Ok(len)
294}