diff options
Diffstat (limited to 'tokio/src/io/util/buf_writer.rs')
-rw-r--r-- | tokio/src/io/util/buf_writer.rs | 196 |
1 files changed, 196 insertions, 0 deletions
diff --git a/tokio/src/io/util/buf_writer.rs b/tokio/src/io/util/buf_writer.rs new file mode 100644 index 00000000..e17281ac --- /dev/null +++ b/tokio/src/io/util/buf_writer.rs @@ -0,0 +1,196 @@ +use crate::io::util::DEFAULT_BUF_SIZE; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; + +use pin_project::{pin_project, project}; +use std::fmt; +use std::io::{self, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Wraps a writer and buffers its output. +/// +/// It can be excessively inefficient to work directly with something that +/// implements [`AsyncWrite`]. A `BufWriter` keeps an in-memory buffer of data and +/// writes it to an underlying writer in large, infrequent batches. +/// +/// `BufWriter` can improve the speed of programs that make *small* and +/// *repeated* write calls to the same file or network socket. It does not +/// help when writing very large amounts at once, or writing just one or a few +/// times. It also provides no advantage when writing to a destination that is +/// in memory, like a `Vec<u8>`. +/// +/// When the `BufWriter` is dropped, the contents of its buffer will be +/// discarded. Creating multiple instances of a `BufWriter` on the same +/// stream can cause data loss. If you need to write out the contents of its +/// buffer, you must manually call flush before the writer is dropped. +/// +/// [`AsyncWrite`]: AsyncWrite +/// [`flush`]: super::AsyncWriteExt::flush +/// +// TODO: Examples +#[pin_project] +pub struct BufWriter<W> { + #[pin] + pub(super) inner: W, + pub(super) buf: Vec<u8>, + pub(super) written: usize, +} + +impl<W: AsyncWrite> BufWriter<W> { + /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB, + /// but may change in the future. + pub fn new(inner: W) -> Self { + Self::with_capacity(DEFAULT_BUF_SIZE, inner) + } + + /// Creates a new `BufWriter` with the specified buffer capacity. + pub fn with_capacity(cap: usize, inner: W) -> Self { + Self { + inner, + buf: Vec::with_capacity(cap), + written: 0, + } + } + + #[project] + fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + #[project] + let BufWriter { + mut inner, + buf, + written, + } = self.project(); + + let len = buf.len(); + let mut ret = Ok(()); + while *written < len { + match ready!(inner.as_mut().poll_write(cx, &buf[*written..])) { + Ok(0) => { + ret = Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write the buffered data", + )); + break; + } + Ok(n) => *written += n, + Err(e) => { + ret = Err(e); + break; + } + } + } + if *written > 0 { + buf.drain(..*written); + } + *written = 0; + Poll::Ready(ret) + } + + /// Gets a reference to the underlying writer. + pub fn get_ref(&self) -> &W { + &self.inner + } + + /// Gets a mutable reference to the underlying writer. + /// + /// It is inadvisable to directly write to the underlying writer. + pub fn get_mut(&mut self) -> &mut W { + &mut self.inner + } + + /// Gets a pinned mutable reference to the underlying writer. + /// + /// It is inadvisable to directly write to the underlying writer. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { + self.project().inner + } + + /// Consumes this `BufWriter`, returning the underlying writer. + /// + /// Note that any leftover data in the internal buffer is lost. + pub fn into_inner(self) -> W { + self.inner + } + + /// Returns a reference to the internally buffered data. + pub fn buffer(&self) -> &[u8] { + &self.buf + } +} + +impl<W: AsyncWrite> AsyncWrite for BufWriter<W> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + if self.buf.len() + buf.len() > self.buf.capacity() { + ready!(self.as_mut().flush_buf(cx))?; + } + + let me = self.project(); + if buf.len() >= me.buf.capacity() { + me.inner.poll_write(cx, buf) + } else { + Poll::Ready(me.buf.write(buf)) + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + ready!(self.as_mut().flush_buf(cx))?; + self.get_pin_mut().poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + ready!(self.as_mut().flush_buf(cx))?; + self.get_pin_mut().poll_shutdown(cx) + } +} + +impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + self.get_pin_mut().poll_read(cx, buf) + } + + // we can't skip unconditionally because of the large buffer case in read. + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.get_ref().prepare_uninitialized_buffer(buf) + } +} + +impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + self.get_pin_mut().poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.get_pin_mut().consume(amt) + } +} + +impl<W: fmt::Debug> fmt::Debug for BufWriter<W> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BufWriter") + .field("writer", &self.inner) + .field( + "buffer", + &format_args!("{}/{}", self.buf.len(), self.buf.capacity()), + ) + .field("written", &self.written) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + crate::is_unpin::<BufWriter<()>>(); + } +} |