From 34fcef258b84d17f8d418b39eb61fa07fa87c390 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 18 Nov 2020 10:41:47 -0800 Subject: io: add vectored writes to `AsyncWrite` (#3149) This adds `AsyncWrite::poll_write_vectored`, and implements it for `TcpStream` and `UnixStream`. Refs: #3135. --- tokio/src/io/async_write.rs | 135 +++++++++++++++++++++++++++++++++++++- tokio/src/io/poll_evented.rs | 13 ++++ tokio/src/net/tcp/split.rs | 12 ++++ tokio/src/net/tcp/split_owned.rs | 12 ++++ tokio/src/net/tcp/stream.rs | 20 ++++++ tokio/src/net/unix/split.rs | 12 ++++ tokio/src/net/unix/split_owned.rs | 12 ++++ tokio/src/net/unix/stream.rs | 22 ++++++- 8 files changed, 236 insertions(+), 2 deletions(-) diff --git a/tokio/src/io/async_write.rs b/tokio/src/io/async_write.rs index 66ba4bf3..569fb9c9 100644 --- a/tokio/src/io/async_write.rs +++ b/tokio/src/io/async_write.rs @@ -1,4 +1,4 @@ -use std::io; +use std::io::{self, IoSlice}; use std::ops::DerefMut; use std::pin::Pin; use std::task::{Context, Poll}; @@ -127,6 +127,55 @@ pub trait AsyncWrite { /// This function will panic if not called within the context of a future's /// task. fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + + /// Like [`poll_write`], except that it writes from a slice of buffers. + /// + /// Data is copied from each buffer in order, with the final buffer + /// read from possibly being only partially consumed. This method must + /// behave as a call to [`write`] with the buffers concatenated would. + /// + /// The default implementation calls [`poll_write`] with either the first nonempty + /// buffer provided, or an empty one if none exists. + /// + /// On success, returns `Poll::Ready(Ok(num_bytes_written))`. + /// + /// If the object is not ready for writing, the method returns + /// `Poll::Pending` and arranges for the current task (via + /// `cx.waker()`) to receive a notification when the object becomes + /// writable or is closed. + /// + /// # Note + /// + /// This should be implemented as a single "atomic" write action. If any + /// data has been partially written, it is wrong to return an error or + /// pending. + /// + /// [`poll_write`]: AsyncWrite::poll_write + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let buf = bufs + .iter() + .find(|b| !b.is_empty()) + .map_or(&[][..], |b| &**b); + self.poll_write(cx, buf) + } + + /// Determines if this writer has an efficient [`poll_write_vectored`] + /// implementation. + /// + /// If a writer does not override the default [`poll_write_vectored`] + /// implementation, code using it may want to avoid the method all together + /// and coalesce writes into a single buffer for higher performance. + /// + /// The default implementation returns `false`. + /// + /// [`poll_write_vectored`]: AsyncWrite::poll_write_vectored + fn is_write_vectored(&self) -> bool { + false + } } macro_rules! deref_async_write { @@ -139,6 +188,18 @@ macro_rules! deref_async_write { Pin::new(&mut **self).poll_write(cx, buf) } + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut **self).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + (**self).is_write_vectored() + } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut **self).poll_flush(cx) } @@ -170,6 +231,18 @@ where self.get_mut().as_mut().poll_write(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + self.get_mut().as_mut().poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + (**self).is_write_vectored() + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.get_mut().as_mut().poll_flush(cx) } @@ -189,6 +262,18 @@ impl AsyncWrite for Vec { Poll::Ready(Ok(buf.len())) } + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Poll::Ready(io::Write::write_vectored(&mut *self, bufs)) + } + + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } @@ -207,6 +292,18 @@ impl AsyncWrite for io::Cursor<&mut [u8]> { Poll::Ready(io::Write::write(&mut *self, buf)) } + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Poll::Ready(io::Write::write_vectored(&mut *self, bufs)) + } + + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(io::Write::flush(&mut *self)) } @@ -225,6 +322,18 @@ impl AsyncWrite for io::Cursor<&mut Vec> { Poll::Ready(io::Write::write(&mut *self, buf)) } + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Poll::Ready(io::Write::write_vectored(&mut *self, bufs)) + } + + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(io::Write::flush(&mut *self)) } @@ -243,6 +352,18 @@ impl AsyncWrite for io::Cursor> { Poll::Ready(io::Write::write(&mut *self, buf)) } + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Poll::Ready(io::Write::write_vectored(&mut *self, bufs)) + } + + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(io::Write::flush(&mut *self)) } @@ -261,6 +382,18 @@ impl AsyncWrite for io::Cursor> { Poll::Ready(io::Write::write(&mut *self, buf)) } + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Poll::Ready(io::Write::write_vectored(&mut *self, bufs)) + } + + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(io::Write::flush(&mut *self)) } diff --git a/tokio/src/io/poll_evented.rs b/tokio/src/io/poll_evented.rs index 803932ba..3a659610 100644 --- a/tokio/src/io/poll_evented.rs +++ b/tokio/src/io/poll_evented.rs @@ -163,6 +163,19 @@ feature! { use std::io::Write; self.registration.poll_write_io(cx, || self.io.as_ref().unwrap().write(buf)) } + + #[cfg(feature = "net")] + pub(crate) fn poll_write_vectored<'a>( + &'a self, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> + where + &'a E: io::Write + 'a, + { + use std::io::Write; + self.registration.poll_write_io(cx, || self.io.as_ref().unwrap().write_vectored(bufs)) + } } } diff --git a/tokio/src/net/tcp/split.rs b/tokio/src/net/tcp/split.rs index 9a257f8b..28c94eb4 100644 --- a/tokio/src/net/tcp/split.rs +++ b/tokio/src/net/tcp/split.rs @@ -147,6 +147,18 @@ impl AsyncWrite for WriteHalf<'_> { self.0.poll_write_priv(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.0.poll_write_vectored_priv(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } + #[inline] fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { // tcp flush is a no-op diff --git a/tokio/src/net/tcp/split_owned.rs b/tokio/src/net/tcp/split_owned.rs index 4b4e2636..8d77c8ca 100644 --- a/tokio/src/net/tcp/split_owned.rs +++ b/tokio/src/net/tcp/split_owned.rs @@ -229,6 +229,18 @@ impl AsyncWrite for OwnedWriteHalf { self.inner.poll_write_priv(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.inner.poll_write_vectored_priv(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + #[inline] fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { // tcp flush is a no-op diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index de7b4213..28118f73 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -832,6 +832,14 @@ impl TcpStream { ) -> Poll> { self.io.poll_write(cx, buf) } + + pub(super) fn poll_write_vectored_priv( + &self, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.io.poll_write_vectored(cx, bufs) + } } impl TryFrom for TcpStream { @@ -867,6 +875,18 @@ impl AsyncWrite for TcpStream { self.poll_write_priv(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.poll_write_vectored_priv(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + true + } + #[inline] fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { // tcp flush is a no-op diff --git a/tokio/src/net/unix/split.rs b/tokio/src/net/unix/split.rs index 460bbc19..af9c7624 100644 --- a/tokio/src/net/unix/split.rs +++ b/tokio/src/net/unix/split.rs @@ -68,6 +68,18 @@ impl AsyncWrite for WriteHalf<'_> { self.0.poll_write_priv(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.0.poll_write_vectored_priv(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } diff --git a/tokio/src/net/unix/split_owned.rs b/tokio/src/net/unix/split_owned.rs index ab233072..5f0a2593 100644 --- a/tokio/src/net/unix/split_owned.rs +++ b/tokio/src/net/unix/split_owned.rs @@ -153,6 +153,18 @@ impl AsyncWrite for OwnedWriteHalf { self.inner.poll_write_priv(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.inner.poll_write_vectored_priv(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + #[inline] fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { // flush is a no-op diff --git a/tokio/src/net/unix/stream.rs b/tokio/src/net/unix/stream.rs index 1d840926..f9619942 100644 --- a/tokio/src/net/unix/stream.rs +++ b/tokio/src/net/unix/stream.rs @@ -172,6 +172,18 @@ impl AsyncWrite for UnixStream { self.poll_write_priv(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.poll_write_vectored_priv(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } @@ -199,7 +211,7 @@ impl UnixStream { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - // Safety: `UdpStream::read` correctly handles reads into uninitialized memory + // Safety: `UnixStream::read` correctly handles reads into uninitialized memory unsafe { self.io.poll_read(cx, buf) } } @@ -210,6 +222,14 @@ impl UnixStream { ) -> Poll> { self.io.poll_write(cx, buf) } + + pub(super) fn poll_write_vectored_priv( + &self, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.io.poll_write_vectored(cx, bufs) + } } impl fmt::Debug for UnixStream { -- cgit v1.2.3