diff options
author | Sean McArthur <sean@seanmonstar.com> | 2020-08-13 20:15:01 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-13 20:15:01 -0700 |
commit | c393236dfd12c13e82badd631d3a3a90481c6f95 (patch) | |
tree | 47e7e70b7a58fb968870d5d44e95f6c45192e114 /tokio/src/net | |
parent | 71da06097bf9aa851ebdde79d7b01a3e38174db9 (diff) |
io: change AsyncRead to use a ReadBuf (#2758)
Works towards #2716. Changes the argument to `AsyncRead::poll_read` to
take a `ReadBuf` struct that safely manages writes to uninitialized memory.
Diffstat (limited to 'tokio/src/net')
-rw-r--r-- | tokio/src/net/tcp/split.rs | 11 | ||||
-rw-r--r-- | tokio/src/net/tcp/split_owned.rs | 11 | ||||
-rw-r--r-- | tokio/src/net/tcp/stream.rs | 31 | ||||
-rw-r--r-- | tokio/src/net/unix/split.rs | 11 | ||||
-rw-r--r-- | tokio/src/net/unix/split_owned.rs | 11 | ||||
-rw-r--r-- | tokio/src/net/unix/stream.rs | 31 |
6 files changed, 50 insertions, 56 deletions
diff --git a/tokio/src/net/tcp/split.rs b/tokio/src/net/tcp/split.rs index 0c1e359f..9d99d7bd 100644 --- a/tokio/src/net/tcp/split.rs +++ b/tokio/src/net/tcp/split.rs @@ -9,12 +9,11 @@ //! level. use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::net::TcpStream; use bytes::Buf; use std::io; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::pin::Pin; use std::task::{Context, Poll}; @@ -131,15 +130,11 @@ impl ReadHalf<'_> { } impl AsyncRead for ReadHalf<'_> { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.0.poll_read_priv(cx, buf) } } diff --git a/tokio/src/net/tcp/split_owned.rs b/tokio/src/net/tcp/split_owned.rs index 6c2b9e69..87be6efd 100644 --- a/tokio/src/net/tcp/split_owned.rs +++ b/tokio/src/net/tcp/split_owned.rs @@ -9,12 +9,11 @@ //! level. use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::net::TcpStream; use bytes::Buf; use std::error::Error; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::pin::Pin; use std::sync::Arc; @@ -186,15 +185,11 @@ impl OwnedReadHalf { } impl AsyncRead for OwnedReadHalf { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.inner.poll_read_priv(cx, buf) } } diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index 02b52627..e624fb9d 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -1,5 +1,5 @@ use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite, PollEvented}; +use crate::io::{AsyncRead, AsyncWrite, PollEvented, ReadBuf}; use crate::net::tcp::split::{split, ReadHalf, WriteHalf}; use crate::net::tcp::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf}; use crate::net::ToSocketAddrs; @@ -9,7 +9,6 @@ use iovec::IoVec; use std::convert::TryFrom; use std::fmt; use std::io::{self, Read, Write}; -use std::mem::MaybeUninit; use std::net::{self, Shutdown, SocketAddr}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -702,16 +701,28 @@ impl TcpStream { pub(crate) fn poll_read_priv( &self, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; - match self.io.get_ref().read(buf) { + // Safety: `TcpStream::read` will not peak at the maybe uinitialized bytes. + let b = + unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; + match self.io.get_ref().read(b) { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { self.io.clear_read_ready(cx, mio::Ready::readable())?; Poll::Pending } - x => Poll::Ready(x), + Ok(n) => { + // Safety: We trust `TcpStream::read` to have filled up `n` bytes + // in the buffer. + unsafe { + buf.assume_init(n); + } + buf.add_filled(n); + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e)), } } @@ -864,15 +875,11 @@ impl TryFrom<net::TcpStream> for TcpStream { // ===== impl Read / Write ===== impl AsyncRead for TcpStream { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.poll_read_priv(cx, buf) } } diff --git a/tokio/src/net/unix/split.rs b/tokio/src/net/unix/split.rs index 4fd85774..460bbc19 100644 --- a/tokio/src/net/unix/split.rs +++ b/tokio/src/net/unix/split.rs @@ -8,11 +8,10 @@ //! split has no associated overhead and enforces all invariants at the type //! level. -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::net::UnixStream; use std::io; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::pin::Pin; use std::task::{Context, Poll}; @@ -51,15 +50,11 @@ pub(crate) fn split(stream: &mut UnixStream) -> (ReadHalf<'_>, WriteHalf<'_>) { } impl AsyncRead for ReadHalf<'_> { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.0.poll_read_priv(cx, buf) } } diff --git a/tokio/src/net/unix/split_owned.rs b/tokio/src/net/unix/split_owned.rs index eb35304b..ab233072 100644 --- a/tokio/src/net/unix/split_owned.rs +++ b/tokio/src/net/unix/split_owned.rs @@ -8,11 +8,10 @@ //! split has no associated overhead and enforces all invariants at the type //! level. -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::net::UnixStream; use std::error::Error; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::pin::Pin; use std::sync::Arc; @@ -109,15 +108,11 @@ impl OwnedReadHalf { } impl AsyncRead for OwnedReadHalf { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.inner.poll_read_priv(cx, buf) } } diff --git a/tokio/src/net/unix/stream.rs b/tokio/src/net/unix/stream.rs index 5fe242d0..559fe02a 100644 --- a/tokio/src/net/unix/stream.rs +++ b/tokio/src/net/unix/stream.rs @@ -1,5 +1,5 @@ use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite, PollEvented}; +use crate::io::{AsyncRead, AsyncWrite, PollEvented, ReadBuf}; use crate::net::unix::split::{split, ReadHalf, WriteHalf}; use crate::net::unix::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf}; use crate::net::unix::ucred::{self, UCred}; @@ -7,7 +7,6 @@ use crate::net::unix::ucred::{self, UCred}; use std::convert::TryFrom; use std::fmt; use std::io::{self, Read, Write}; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::{self, SocketAddr}; @@ -167,15 +166,11 @@ impl TryFrom<net::UnixStream> for UnixStream { } impl AsyncRead for UnixStream { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.poll_read_priv(cx, buf) } } @@ -214,16 +209,28 @@ impl UnixStream { pub(crate) fn poll_read_priv( &self, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; - match self.io.get_ref().read(buf) { + // Safety: `UnixStream::read` will not peak at the maybe uinitialized bytes. + let b = + unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; + match self.io.get_ref().read(b) { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { self.io.clear_read_ready(cx, mio::Ready::readable())?; Poll::Pending } - x => Poll::Ready(x), + Ok(n) => { + // Safety: We trust `UnixStream::read` to have filled up `n` bytes + // in the buffer. + unsafe { + buf.assume_init(n); + } + buf.add_filled(n); + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e)), } } |