From 647299866a2262c8a1183adad73673e5803293ed Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Thu, 3 Dec 2020 11:19:16 -0800 Subject: util: add writev-aware `poll_write_buf` (#3156) ## Motivation In Tokio 0.2, `AsyncRead` and `AsyncWrite` had `poll_write_buf` and `poll_read_buf` methods for reading and writing to implementers of `bytes` `Buf` and `BufMut` traits. In 0.3, these were removed, but `poll_read_buf` was added as a free function in `tokio-util`. However, there is currently no `poll_write_buf`. Now that `AsyncWrite` has regained support for vectored writes in #3149, there's a lot of potential benefit in having a `poll_write_buf` that uses vectored writes when supported and non-vectored writes when not supported, so that users don't have to reimplement this. ## Solution This PR adds a `poll_write_buf` function to `tokio_util::io`, analogous to the existing `poll_read_buf` function. This function writes from a `Buf` to an `AsyncWrite`, advancing the `Buf`'s internal cursor. In addition, when the `AsyncWrite` supports vectored writes (i.e. its `is_write_vectored` method returns `true`), it will use vectored IO. I copied the documentation for this functions from the docs from Tokio 0.2's `AsyncWrite::poll_write_buf` , with some minor modifications as appropriate. Finally, I fixed a minor issue in the existing docs for `poll_read_buf` and `read_buf`, and updated `tokio_util::codec` to use `poll_write_buf`. Signed-off-by: Eliza Weisman --- tokio-util/Cargo.toml | 2 +- tokio-util/src/codec/framed_impl.rs | 8 ++-- tokio-util/src/io/mod.rs | 2 +- tokio-util/src/io/read_buf.rs | 4 +- tokio-util/src/lib.rs | 75 ++++++++++++++++++++++++++++++++++--- 5 files changed, 77 insertions(+), 14 deletions(-) diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml index 1c0ee628..7a1e39c8 100644 --- a/tokio-util/Cargo.toml +++ b/tokio-util/Cargo.toml @@ -34,7 +34,7 @@ io = [] rt = ["tokio/rt"] [dependencies] -tokio = { version = "0.3.0", path = "../tokio" } +tokio = { version = "0.3.4", path = "../tokio" } bytes = "0.6.0" futures-core = "0.3.0" diff --git a/tokio-util/src/codec/framed_impl.rs b/tokio-util/src/codec/framed_impl.rs index e8b29999..207e198d 100644 --- a/tokio-util/src/codec/framed_impl.rs +++ b/tokio-util/src/codec/framed_impl.rs @@ -6,7 +6,7 @@ use tokio::{ stream::Stream, }; -use bytes::{Buf, BytesMut}; +use bytes::BytesMut; use futures_core::ready; use futures_sink::Sink; use log::trace; @@ -189,6 +189,7 @@ where } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use crate::util::poll_write_buf; trace!("flushing framed transport"); let mut pinned = self.project(); @@ -196,8 +197,7 @@ where let WriteFrame { buffer } = pinned.state.borrow_mut(); trace!("writing; remaining={}", buffer.len()); - let buf = &buffer; - let n = ready!(pinned.inner.as_mut().poll_write(cx, &buf))?; + let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?; if n == 0 { return Poll::Ready(Err(io::Error::new( @@ -207,8 +207,6 @@ where ) .into())); } - - pinned.state.borrow_mut().buffer.advance(n); } // Try flushing the underlying IO diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs index eefd65a5..eec74448 100644 --- a/tokio-util/src/io/mod.rs +++ b/tokio-util/src/io/mod.rs @@ -13,4 +13,4 @@ mod stream_reader; pub use self::read_buf::read_buf; pub use self::reader_stream::ReaderStream; pub use self::stream_reader::StreamReader; -pub use crate::util::poll_read_buf; +pub use crate::util::{poll_read_buf, poll_write_buf}; diff --git a/tokio-util/src/io/read_buf.rs b/tokio-util/src/io/read_buf.rs index cc3c505f..a5d46a7d 100644 --- a/tokio-util/src/io/read_buf.rs +++ b/tokio-util/src/io/read_buf.rs @@ -5,9 +5,9 @@ use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::AsyncRead; -/// Read data from an `AsyncRead` into an implementer of the [`Buf`] trait. +/// Read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. /// -/// [`Buf`]: bytes::Buf +/// [`BufMut`]: bytes::BufMut /// /// # Example /// diff --git a/tokio-util/src/lib.rs b/tokio-util/src/lib.rs index c4d80440..15bfc1a2 100644 --- a/tokio-util/src/lib.rs +++ b/tokio-util/src/lib.rs @@ -55,18 +55,18 @@ pub mod time; #[cfg(any(feature = "io", feature = "codec"))] mod util { - use tokio::io::{AsyncRead, ReadBuf}; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - use bytes::BufMut; + use bytes::{Buf, BufMut}; use futures_core::ready; - use std::io; + use std::io::{self, IoSlice}; use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; - /// Try to read data from an `AsyncRead` into an implementer of the [`Buf`] trait. + /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. /// - /// [`Buf`]: bytes::Buf + /// [`BufMut`]: bytes::Buf /// /// # Example /// @@ -132,4 +132,69 @@ mod util { Poll::Ready(Ok(n)) } + + /// Try to write data from an implementer of the [`Buf`] trait to an + /// [`AsyncWrite`], advancing the buffer's internal cursor. + /// + /// This function will use [vectored writes] when the [`AsyncWrite`] supports + /// vectored writes. + /// + /// # Examples + /// + /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements + /// [`Buf`]: + /// + /// ```no_run + /// use tokio_util::io::poll_write_buf; + /// use tokio::io; + /// use tokio::fs::File; + /// + /// use bytes::Buf; + /// use std::io::Cursor; + /// use std::pin::Pin; + /// use futures::future::poll_fn; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// let mut buf = Cursor::new(b"data to write"); + /// + /// // Loop until the entire contents of the buffer are written to + /// // the file. + /// while buf.has_remaining() { + /// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?; + /// } + /// + /// Ok(()) + /// } + /// ``` + /// + /// [`Buf`]: bytes::Buf + /// [`AsyncWrite`]: tokio::io::AsyncWrite + /// [`File`]: tokio::fs::File + /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored + #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] + pub fn poll_write_buf( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> { + const MAX_BUFS: usize = 64; + + if !buf.has_remaining() { + return Poll::Ready(Ok(0)); + } + + let n = if io.is_write_vectored() { + let mut slices = [IoSlice::new(&[]); MAX_BUFS]; + let cnt = buf.bytes_vectored(&mut slices); + ready!(io.poll_write_vectored(cx, &slices[..cnt]))? + } else { + ready!(io.poll_write(cx, buf.bytes()))? + }; + + buf.advance(n); + + Poll::Ready(Ok(n)) + } } -- cgit v1.2.3