summaryrefslogtreecommitdiffstats
path: root/tokio/src/net
diff options
context:
space:
mode:
authorSean McArthur <sean@seanmonstar.com>2020-08-13 20:15:01 -0700
committerGitHub <noreply@github.com>2020-08-13 20:15:01 -0700
commitc393236dfd12c13e82badd631d3a3a90481c6f95 (patch)
tree47e7e70b7a58fb968870d5d44e95f6c45192e114 /tokio/src/net
parent71da06097bf9aa851ebdde79d7b01a3e38174db9 (diff)
io: change AsyncRead to use a ReadBuf (#2758)
Works towards #2716. Changes the argument to `AsyncRead::poll_read` to take a `ReadBuf` struct that safely manages writes to uninitialized memory.
Diffstat (limited to 'tokio/src/net')
-rw-r--r--tokio/src/net/tcp/split.rs11
-rw-r--r--tokio/src/net/tcp/split_owned.rs11
-rw-r--r--tokio/src/net/tcp/stream.rs31
-rw-r--r--tokio/src/net/unix/split.rs11
-rw-r--r--tokio/src/net/unix/split_owned.rs11
-rw-r--r--tokio/src/net/unix/stream.rs31
6 files changed, 50 insertions, 56 deletions
diff --git a/tokio/src/net/tcp/split.rs b/tokio/src/net/tcp/split.rs
index 0c1e359f..9d99d7bd 100644
--- a/tokio/src/net/tcp/split.rs
+++ b/tokio/src/net/tcp/split.rs
@@ -9,12 +9,11 @@
//! level.
use crate::future::poll_fn;
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::net::TcpStream;
use bytes::Buf;
use std::io;
-use std::mem::MaybeUninit;
use std::net::Shutdown;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -131,15 +130,11 @@ impl ReadHalf<'_> {
}
impl AsyncRead for ReadHalf<'_> {
- unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.0.poll_read_priv(cx, buf)
}
}
diff --git a/tokio/src/net/tcp/split_owned.rs b/tokio/src/net/tcp/split_owned.rs
index 6c2b9e69..87be6efd 100644
--- a/tokio/src/net/tcp/split_owned.rs
+++ b/tokio/src/net/tcp/split_owned.rs
@@ -9,12 +9,11 @@
//! level.
use crate::future::poll_fn;
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::net::TcpStream;
use bytes::Buf;
use std::error::Error;
-use std::mem::MaybeUninit;
use std::net::Shutdown;
use std::pin::Pin;
use std::sync::Arc;
@@ -186,15 +185,11 @@ impl OwnedReadHalf {
}
impl AsyncRead for OwnedReadHalf {
- unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.inner.poll_read_priv(cx, buf)
}
}
diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs
index 02b52627..e624fb9d 100644
--- a/tokio/src/net/tcp/stream.rs
+++ b/tokio/src/net/tcp/stream.rs
@@ -1,5 +1,5 @@
use crate::future::poll_fn;
-use crate::io::{AsyncRead, AsyncWrite, PollEvented};
+use crate::io::{AsyncRead, AsyncWrite, PollEvented, ReadBuf};
use crate::net::tcp::split::{split, ReadHalf, WriteHalf};
use crate::net::tcp::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf};
use crate::net::ToSocketAddrs;
@@ -9,7 +9,6 @@ use iovec::IoVec;
use std::convert::TryFrom;
use std::fmt;
use std::io::{self, Read, Write};
-use std::mem::MaybeUninit;
use std::net::{self, Shutdown, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -702,16 +701,28 @@ impl TcpStream {
pub(crate) fn poll_read_priv(
&self,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
- match self.io.get_ref().read(buf) {
+ // Safety: `TcpStream::read` will not peak at the maybe uinitialized bytes.
+ let b =
+ unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };
+ match self.io.get_ref().read(b) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_read_ready(cx, mio::Ready::readable())?;
Poll::Pending
}
- x => Poll::Ready(x),
+ Ok(n) => {
+ // Safety: We trust `TcpStream::read` to have filled up `n` bytes
+ // in the buffer.
+ unsafe {
+ buf.assume_init(n);
+ }
+ buf.add_filled(n);
+ Poll::Ready(Ok(()))
+ }
+ Err(e) => Poll::Ready(Err(e)),
}
}
@@ -864,15 +875,11 @@ impl TryFrom<net::TcpStream> for TcpStream {
// ===== impl Read / Write =====
impl AsyncRead for TcpStream {
- unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.poll_read_priv(cx, buf)
}
}
diff --git a/tokio/src/net/unix/split.rs b/tokio/src/net/unix/split.rs
index 4fd85774..460bbc19 100644
--- a/tokio/src/net/unix/split.rs
+++ b/tokio/src/net/unix/split.rs
@@ -8,11 +8,10 @@
//! split has no associated overhead and enforces all invariants at the type
//! level.
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::net::UnixStream;
use std::io;
-use std::mem::MaybeUninit;
use std::net::Shutdown;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -51,15 +50,11 @@ pub(crate) fn split(stream: &mut UnixStream) -> (ReadHalf<'_>, WriteHalf<'_>) {
}
impl AsyncRead for ReadHalf<'_> {
- unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.0.poll_read_priv(cx, buf)
}
}
diff --git a/tokio/src/net/unix/split_owned.rs b/tokio/src/net/unix/split_owned.rs
index eb35304b..ab233072 100644
--- a/tokio/src/net/unix/split_owned.rs
+++ b/tokio/src/net/unix/split_owned.rs
@@ -8,11 +8,10 @@
//! split has no associated overhead and enforces all invariants at the type
//! level.
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::net::UnixStream;
use std::error::Error;
-use std::mem::MaybeUninit;
use std::net::Shutdown;
use std::pin::Pin;
use std::sync::Arc;
@@ -109,15 +108,11 @@ impl OwnedReadHalf {
}
impl AsyncRead for OwnedReadHalf {
- unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.inner.poll_read_priv(cx, buf)
}
}
diff --git a/tokio/src/net/unix/stream.rs b/tokio/src/net/unix/stream.rs
index 5fe242d0..559fe02a 100644
--- a/tokio/src/net/unix/stream.rs
+++ b/tokio/src/net/unix/stream.rs
@@ -1,5 +1,5 @@
use crate::future::poll_fn;
-use crate::io::{AsyncRead, AsyncWrite, PollEvented};
+use crate::io::{AsyncRead, AsyncWrite, PollEvented, ReadBuf};
use crate::net::unix::split::{split, ReadHalf, WriteHalf};
use crate::net::unix::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf};
use crate::net::unix::ucred::{self, UCred};
@@ -7,7 +7,6 @@ use crate::net::unix::ucred::{self, UCred};
use std::convert::TryFrom;
use std::fmt;
use std::io::{self, Read, Write};
-use std::mem::MaybeUninit;
use std::net::Shutdown;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::{self, SocketAddr};
@@ -167,15 +166,11 @@ impl TryFrom<net::UnixStream> for UnixStream {
}
impl AsyncRead for UnixStream {
- unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.poll_read_priv(cx, buf)
}
}
@@ -214,16 +209,28 @@ impl UnixStream {
pub(crate) fn poll_read_priv(
&self,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
- match self.io.get_ref().read(buf) {
+ // Safety: `UnixStream::read` will not peak at the maybe uinitialized bytes.
+ let b =
+ unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };
+ match self.io.get_ref().read(b) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_read_ready(cx, mio::Ready::readable())?;
Poll::Pending
}
- x => Poll::Ready(x),
+ Ok(n) => {
+ // Safety: We trust `UnixStream::read` to have filled up `n` bytes
+ // in the buffer.
+ unsafe {
+ buf.assume_init(n);
+ }
+ buf.add_filled(n);
+ Poll::Ready(Ok(()))
+ }
+ Err(e) => Poll::Ready(Err(e)),
}
}