author    Sean McArthur <sean@seanmonstar.com>  2020-08-13 20:15:01 -0700
committer GitHub <noreply@github.com>           2020-08-13 20:15:01 -0700
commit    c393236dfd12c13e82badd631d3a3a90481c6f95 (patch)
tree      47e7e70b7a58fb968870d5d44e95f6c45192e114
parent    71da06097bf9aa851ebdde79d7b01a3e38174db9 (diff)
io: change AsyncRead to use a ReadBuf (#2758)
Works towards #2716. Changes the argument to `AsyncRead::poll_read` to take a `ReadBuf` struct that safely manages writes to uninitialized memory.
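For orientation (an editor's sketch, not part of the commit message): the change replaces the `&mut [u8]` buffer and `usize` return value of `poll_read` with a `ReadBuf` that tracks how much of the buffer has been filled. A minimal illustration of the new trait shape, using a hypothetical `SliceReader` type that exists only for this example:

```rust
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};

// Hypothetical reader used only to illustrate the new trait shape.
struct SliceReader<'a>(&'a [u8]);

impl AsyncRead for SliceReader<'_> {
    // Old signature (pre-#2758):
    //     fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8])
    //         -> Poll<io::Result<usize>>;
    //
    // New signature: the caller passes a `ReadBuf`, and the implementor reports
    // progress by filling it rather than by returning a byte count.
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let n = std::cmp::min(self.0.len(), buf.remaining());
        let (a, rest) = self.0.split_at(n);
        buf.append(a); // copies into the buffer and advances the `filled` cursor
        self.0 = rest;
        Poll::Ready(Ok(()))
    }
}
```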
 tokio-test/src/io.rs                 |  43
 tokio-util/src/compat.rs             |  26
 tokio-util/tests/framed.rs           |   4
 tokio-util/tests/framed_read.rs      |  18
 tokio-util/tests/length_delimited.rs |  14
 tokio/src/fs/file.rs                 |  19
 tokio/src/io/async_read.rs           | 125
 tokio/src/io/blocking.rs             |  24
 tokio/src/io/mod.rs                  |   3
 tokio/src/io/poll_evented.rs         |  14
 tokio/src/io/read_buf.rs             | 253
 tokio/src/io/split.rs                |   6
 tokio/src/io/stdin.rs                |  11
 tokio/src/io/util/buf_reader.rs      |  50
 tokio/src/io/util/buf_stream.rs      |  12
 tokio/src/io/util/buf_writer.rs      |  12
 tokio/src/io/util/chain.rs           |  24
 tokio/src/io/util/copy.rs            |   6
 tokio/src/io/util/empty.rs           |  11
 tokio/src/io/util/mem.rs             |  19
 tokio/src/io/util/read.rs            |   6
 tokio/src/io/util/read_exact.rs      |  22
 tokio/src/io/util/read_int.rs        |  33
 tokio/src/io/util/read_to_end.rs     |  85
 tokio/src/io/util/read_to_string.rs  |   5
 tokio/src/io/util/repeat.rs          |  16
 tokio/src/io/util/stream_reader.rs   |  20
 tokio/src/io/util/take.rs            |  29
 tokio/src/net/tcp/split.rs           |  11
 tokio/src/net/tcp/split_owned.rs     |  11
 tokio/src/net/tcp/stream.rs          |  31
 tokio/src/net/unix/split.rs          |  11
 tokio/src/net/unix/split_owned.rs    |  11
 tokio/src/net/unix/stream.rs         |  31
 tokio/src/process/mod.rs             |  20
 tokio/src/signal/unix.rs             |  14
 tokio/tests/io_async_read.rs         |  67
 tokio/tests/io_copy.rs               |  12
 tokio/tests/io_read.rs               |  32
 tokio/tests/io_split.rs              |   9
 40 files changed, 626 insertions(+), 544 deletions(-)
diff --git a/tokio-test/src/io.rs b/tokio-test/src/io.rs
index 26ef57e4..f1ce77aa 100644
--- a/tokio-test/src/io.rs
+++ b/tokio-test/src/io.rs
@@ -18,7 +18,7 @@
//! [`AsyncRead`]: tokio::io::AsyncRead
//! [`AsyncWrite`]: tokio::io::AsyncWrite
-use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc;
use tokio::time::{self, Delay, Duration, Instant};
@@ -204,20 +204,19 @@ impl Inner {
self.rx.poll_recv(cx)
}
- fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
+ fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> {
match self.action() {
Some(&mut Action::Read(ref mut data)) => {
// Figure out how much to copy
- let n = cmp::min(dst.len(), data.len());
+ let n = cmp::min(dst.remaining(), data.len());
// Copy the data into the `dst` slice
- (&mut dst[..n]).copy_from_slice(&data[..n]);
+ dst.append(&data[..n]);
// Drain the data from the source
data.drain(..n);
- // Return the number of bytes read
- Ok(n)
+ Ok(())
}
Some(&mut Action::ReadError(ref mut err)) => {
// As the
@@ -229,7 +228,7 @@ impl Inner {
// Either waiting or expecting a write
Err(io::ErrorKind::WouldBlock.into())
}
- None => Ok(0),
+ None => Ok(()),
}
}
@@ -348,8 +347,8 @@ impl AsyncRead for Mock {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
loop {
if let Some(ref mut sleep) = self.inner.sleep {
ready!(Pin::new(sleep).poll(cx));
@@ -358,6 +357,9 @@ impl AsyncRead for Mock {
// If a sleep is set, it has already fired
self.inner.sleep = None;
+ // Capture 'filled' to monitor if it changed
+ let filled = buf.filled().len();
+
match self.inner.read(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Some(rem) = self.inner.remaining_wait() {
@@ -368,19 +370,22 @@ impl AsyncRead for Mock {
return Poll::Pending;
}
}
- Ok(0) => {
- // TODO: Extract
- match ready!(self.inner.poll_action(cx)) {
- Some(action) => {
- self.inner.actions.push_back(action);
- continue;
- }
- None => {
- return Poll::Ready(Ok(0));
+ Ok(()) => {
+ if buf.filled().len() == filled {
+ match ready!(self.inner.poll_action(cx)) {
+ Some(action) => {
+ self.inner.actions.push_back(action);
+ continue;
+ }
+ None => {
+ return Poll::Ready(Ok(()));
+ }
}
+ } else {
+ return Poll::Ready(Ok(()));
}
}
- ret => return Poll::Ready(ret),
+ Err(e) => return Poll::Ready(Err(e)),
}
}
}
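The mock above has to cope with the fact that `poll_read` now returns `Ok(())` both at end-of-stream and after a successful non-empty read, so it snapshots `buf.filled().len()` before delegating to `Inner::read` and compares afterwards. A sketch of that idiom as a standalone helper (the helper itself is hypothetical, not part of the commit):

```rust
use std::io;
use tokio::io::ReadBuf;

// Hypothetical helper: snapshot the filled length before a read, then compare
// afterwards to tell "no progress" (end-of-stream, or "wait for the next
// scripted action" in the mock) apart from a read that produced data.
fn read_made_progress(
    buf: &mut ReadBuf<'_>,
    mut read_some: impl FnMut(&mut ReadBuf<'_>) -> io::Result<()>,
) -> io::Result<bool> {
    let before = buf.filled().len();
    read_some(buf)?;
    Ok(buf.filled().len() > before)
}
```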
diff --git a/tokio-util/src/compat.rs b/tokio-util/src/compat.rs
index 769e30c2..34120d43 100644
--- a/tokio-util/src/compat.rs
+++ b/tokio-util/src/compat.rs
@@ -1,5 +1,6 @@
//! Compatibility between the `tokio::io` and `futures-io` versions of the
//! `AsyncRead` and `AsyncWrite` traits.
+use futures_core::ready;
use pin_project_lite::pin_project;
use std::io;
use std::pin::Pin;
@@ -107,9 +108,18 @@ where
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- futures_io::AsyncRead::poll_read(self.project().inner, cx, buf)
+ buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ // We can't trust the inner type to not peek at the bytes,
+ // so we must defensively initialize the buffer.
+ let slice = buf.initialize_unfilled();
+ let n = ready!(futures_io::AsyncRead::poll_read(
+ self.project().inner,
+ cx,
+ slice
+ ))?;
+ buf.add_filled(n);
+ Poll::Ready(Ok(()))
}
}
@@ -120,9 +130,15 @@ where
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
+ slice: &mut [u8],
) -> Poll<io::Result<usize>> {
- tokio::io::AsyncRead::poll_read(self.project().inner, cx, buf)
+ let mut buf = tokio::io::ReadBuf::new(slice);
+ ready!(tokio::io::AsyncRead::poll_read(
+ self.project().inner,
+ cx,
+ &mut buf
+ ))?;
+ Poll::Ready(Ok(buf.filled().len()))
}
}
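The adapter above bridges the two trait families by defensively initializing the unfilled region before handing it to a `futures-io` reader, then recording progress with `add_filled`. As a rough usage sketch, assuming tokio-util's `compat` feature and its `FuturesAsyncReadCompatExt` extension trait (those names are not shown in this diff, so treat them as assumptions):

```rust
use tokio::io::AsyncReadExt; // tokio-side `read_to_end`
use tokio_util::compat::FuturesAsyncReadCompatExt;

// Hypothetical helper: drain a `futures-io` reader through the compat layer.
async fn read_all(reader: impl futures_io::AsyncRead + Unpin) -> std::io::Result<Vec<u8>> {
    // `.compat()` wraps the reader; its `poll_read` performs the
    // `initialize_unfilled` / `add_filled` dance shown in the diff above.
    let mut reader = reader.compat();
    let mut out = Vec::new();
    reader.read_to_end(&mut out).await?;
    Ok(out)
}
```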
diff --git a/tokio-util/tests/framed.rs b/tokio-util/tests/framed.rs
index d7ee3ef5..4c5f8418 100644
--- a/tokio-util/tests/framed.rs
+++ b/tokio-util/tests/framed.rs
@@ -55,8 +55,8 @@ impl AsyncRead for DontReadIntoThis {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
- _buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ _buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
unreachable!()
}
}
diff --git a/tokio-util/tests/framed_read.rs b/tokio-util/tests/framed_read.rs
index 27bb298a..da38c432 100644
--- a/tokio-util/tests/framed_read.rs
+++ b/tokio-util/tests/framed_read.rs
@@ -1,6 +1,6 @@
#![warn(rust_2018_idioms)]
-use tokio::io::AsyncRead;
+use tokio::io::{AsyncRead, ReadBuf};
use tokio_test::assert_ready;
use tokio_test::task;
use tokio_util::codec::{Decoder, FramedRead};
@@ -264,19 +264,19 @@ impl AsyncRead for Mock {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
use io::ErrorKind::WouldBlock;
match self.calls.pop_front() {
Some(Ok(data)) => {
- debug_assert!(buf.len() >= data.len());
- buf[..data.len()].copy_from_slice(&data[..]);
- Ready(Ok(data.len()))
+ debug_assert!(buf.remaining() >= data.len());
+ buf.append(&data);
+ Ready(Ok(()))
}
Some(Err(ref e)) if e.kind() == WouldBlock => Pending,
Some(Err(e)) => Ready(Err(e)),
- None => Ready(Ok(0)),
+ None => Ready(Ok(())),
}
}
}
@@ -288,8 +288,8 @@ impl AsyncRead for Slice<'_> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
diff --git a/tokio-util/tests/length_delimited.rs b/tokio-util/tests/length_delimited.rs
index 734cd834..9f615412 100644
--- a/tokio-util/tests/length_delimited.rs
+++ b/tokio-util/tests/length_delimited.rs
@@ -1,6 +1,6 @@
#![warn(rust_2018_idioms)]
-use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_test::task;
use tokio_test::{
assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok,
@@ -707,18 +707,18 @@ impl AsyncRead for Mock {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
- dst: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ dst: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
match self.calls.pop_front() {
Some(Ready(Ok(Op::Data(data)))) => {
- debug_assert!(dst.len() >= data.len());
- dst[..data.len()].copy_from_slice(&data[..]);
- Ready(Ok(data.len()))
+ debug_assert!(dst.remaining() >= data.len());
+ dst.append(&data);
+ Ready(Ok(()))
}
Some(Ready(Ok(_))) => panic!(),
Some(Ready(Err(e))) => Ready(Err(e)),
Some(Pending) => Pending,
- None => Ready(Ok(0)),
+ None => Ready(Ok(())),
}
}
}
diff --git a/tokio/src/fs/file.rs b/tokio/src/fs/file.rs
index c44196b3..2c36806d 100644
--- a/tokio/src/fs/file.rs
+++ b/tokio/src/fs/file.rs
@@ -5,7 +5,7 @@
use self::State::*;
use crate::fs::{asyncify, sys};
use crate::io::blocking::Buf;
-use crate::io::{AsyncRead, AsyncSeek, AsyncWrite};
+use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
use std::fmt;
use std::fs::{Metadata, Permissions};
@@ -537,25 +537,20 @@ impl File {
}
impl AsyncRead for File {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
- // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/fs.rs#L668
- false
- }
-
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- dst: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ dst: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
loop {
match self.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();
if !buf.is_empty() {
- let n = buf.copy_to(dst);
+ buf.copy_to(dst);
*buf_cell = Some(buf);
- return Ready(Ok(n));
+ return Ready(Ok(()));
}
buf.ensure_capacity_for(dst);
@@ -571,9 +566,9 @@ impl AsyncRead for File {
match op {
Operation::Read(Ok(_)) => {
- let n = buf.copy_to(dst);
+ buf.copy_to(dst);
self.state = Idle(Some(buf));
- return Ready(Ok(n));
+ return Ready(Ok(()));
}
Operation::Read(Err(e)) => {
assert!(buf.is_empty());
diff --git a/tokio/src/io/async_read.rs b/tokio/src/io/async_read.rs
index 1aef4150..d341b63d 100644
--- a/tokio/src/io/async_read.rs
+++ b/tokio/src/io/async_read.rs
@@ -1,6 +1,6 @@
+use super::ReadBuf;
use bytes::BufMut;
use std::io;
-use std::mem::MaybeUninit;
use std::ops::DerefMut;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -41,47 +41,6 @@ use std::task::{Context, Poll};
/// [`Read::read`]: std::io::Read::read
/// [`AsyncReadExt`]: crate::io::AsyncReadExt
pub trait AsyncRead {
- /// Prepares an uninitialized buffer to be safe to pass to `read`. Returns
- /// `true` if the supplied buffer was zeroed out.
- ///
- /// While it would be highly unusual, implementations of [`io::Read`] are
- /// able to read data from the buffer passed as an argument. Because of
- /// this, the buffer passed to [`io::Read`] must be initialized memory. In
- /// situations where large numbers of buffers are used, constantly having to
- /// zero out buffers can be expensive.
- ///
- /// This function does any necessary work to prepare an uninitialized buffer
- /// to be safe to pass to `read`. If `read` guarantees to never attempt to
- /// read data out of the supplied buffer, then `prepare_uninitialized_buffer`
- /// doesn't need to do any work.
- ///
- /// If this function returns `true`, then the memory has been zeroed out.
- /// This allows implementations of `AsyncRead` which are composed of
- /// multiple subimplementations to efficiently implement
- /// `prepare_uninitialized_buffer`.
- ///
- /// This function isn't actually `unsafe` to call but `unsafe` to implement.
- /// The implementer must ensure that either the whole `buf` has been zeroed
- /// or `poll_read_buf()` overwrites the buffer without reading it and returns
- /// correct value.
- ///
- /// This function is called from [`poll_read_buf`].
- ///
- /// # Safety
- ///
- /// Implementations that return `false` must never read from data slices
- /// that they did not write to.
- ///
- /// [`io::Read`]: std::io::Read
- /// [`poll_read_buf`]: method@Self::poll_read_buf
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- for x in buf {
- *x = MaybeUninit::new(0);
- }
-
- true
- }
-
/// Attempts to read from the `AsyncRead` into `buf`.
///
/// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
@@ -93,8 +52,8 @@ pub trait AsyncRead {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>>;
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>>;
/// Pulls some bytes from this source into the specified `BufMut`, returning
/// how many bytes were read.
@@ -114,37 +73,26 @@ pub trait AsyncRead {
return Poll::Ready(Ok(0));
}
- unsafe {
- let n = {
- let b = buf.bytes_mut();
-
- self.prepare_uninitialized_buffer(b);
-
- // Convert to `&mut [u8]`
- let b = &mut *(b as *mut [MaybeUninit<u8>] as *mut [u8]);
+ let mut b = ReadBuf::uninit(buf.bytes_mut());
- let n = ready!(self.poll_read(cx, b))?;
- assert!(n <= b.len(), "Bad AsyncRead implementation, more bytes were reported as read than the buffer can hold");
- n
- };
+ ready!(self.poll_read(cx, &mut b))?;
+ let n = b.filled().len();
+ // Safety: we can assume `n` bytes were read, since they are in `filled`.
+ unsafe {
buf.advance_mut(n);
- Poll::Ready(Ok(n))
}
+ Poll::Ready(Ok(n))
}
}
macro_rules! deref_async_read {
() => {
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- (**self).prepare_uninitialized_buffer(buf)
- }
-
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
Pin::new(&mut **self).poll_read(cx, buf)
}
};
@@ -163,43 +111,50 @@ where
P: DerefMut + Unpin,
P::Target: AsyncRead,
{
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- (**self).prepare_uninitialized_buffer(buf)
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.get_mut().as_mut().poll_read(cx, buf)
}
}
impl AsyncRead for &[u8] {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
- self: Pin<&mut Self>,
+ mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- Poll::Ready(io::Read::read(self.get_mut(), buf))
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ let amt = std::cmp::min(self.len(), buf.remaining());
+ let (a, b) = self.split_at(amt);
+ buf.append(a);
+ *self = b;
+ Poll::Ready(Ok(()))
}
}
impl<T: AsRef<[u8]> + Unpin> AsyncRead for io::Cursor<T> {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
- self: Pin<&mut Self>,
+ mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- Poll::Ready(io::Read::read(self.get_mut(), buf))
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ let pos = self.position();
+ let slice: &[u8] = (*self).get_ref().as_ref();
+
+ // The position could technically be out of bounds, so don't panic...
+ if pos > slice.len() as u64 {
+ return Poll::Ready(Ok(()));
+ }
+
+ let start = pos as usize;
+ let amt = std::cmp::min(slice.len() - start, buf.remaining());
+ // Add won't overflow because of pos check above.
+ let end = start + amt;
+ buf.append(&slice[start..end]);
+ self.set_position(end as u64);
+
+ Poll::Ready(Ok(()))
}
}
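The slice and `Cursor` implementations above reduce to moving the `ReadBuf` cursors. A small, self-contained sketch of how those cursors behave, using only methods introduced by this commit (an editor's illustration, not code from the patch):

```rust
use tokio::io::ReadBuf;

fn main() {
    // `new` treats the entire backing slice as already initialized.
    let mut backing = [0u8; 8];
    let mut buf = ReadBuf::new(&mut backing);

    assert_eq!(buf.capacity(), 8);
    assert_eq!(buf.remaining(), 8);
    assert!(buf.filled().is_empty());

    // `append` copies bytes into the unfilled region and advances the
    // `filled` cursor, which is what the `&[u8]` and `Cursor` impls rely on.
    buf.append(b"hello");
    assert_eq!(buf.filled(), &b"hello"[..]);
    assert_eq!(buf.remaining(), 3);
}
```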
diff --git a/tokio/src/io/blocking.rs b/tokio/src/io/blocking.rs
index 2491039a..d2265a00 100644
--- a/tokio/src/io/blocking.rs
+++ b/tokio/src/io/blocking.rs
@@ -1,5 +1,5 @@
use crate::io::sys;
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use std::cmp;
use std::future::Future;
@@ -53,17 +53,17 @@ where
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- dst: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ dst: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
loop {
match self.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();
if !buf.is_empty() {
- let n = buf.copy_to(dst);
+ buf.copy_to(dst);
*buf_cell = Some(buf);
- return Ready(Ok(n));
+ return Ready(Ok(()));
}
buf.ensure_capacity_for(dst);
@@ -80,9 +80,9 @@ where
match res {
Ok(_) => {
- let n = buf.copy_to(dst);
+ buf.copy_to(dst);
self.state = Idle(Some(buf));
- return Ready(Ok(n));
+ return Ready(Ok(()));
}
Err(e) => {
assert!(buf.is_empty());
@@ -203,9 +203,9 @@ impl Buf {
self.buf.len() - self.pos
}
- pub(crate) fn copy_to(&mut self, dst: &mut [u8]) -> usize {
- let n = cmp::min(self.len(), dst.len());
- dst[..n].copy_from_slice(&self.bytes()[..n]);
+ pub(crate) fn copy_to(&mut self, dst: &mut ReadBuf<'_>) -> usize {
+ let n = cmp::min(self.len(), dst.remaining());
+ dst.append(&self.bytes()[..n]);
self.pos += n;
if self.pos == self.buf.len() {
@@ -229,10 +229,10 @@ impl Buf {
&self.buf[self.pos..]
}
- pub(crate) fn ensure_capacity_for(&mut self, bytes: &[u8]) {
+ pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) {
assert!(self.is_empty());
- let len = cmp::min(bytes.len(), MAX_BUF);
+ let len = cmp::min(bytes.remaining(), MAX_BUF);
if self.buf.len() < len {
self.buf.reserve(len - self.buf.len());
diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs
index 9e0e0631..c43f0e83 100644
--- a/tokio/src/io/mod.rs
+++ b/tokio/src/io/mod.rs
@@ -196,6 +196,9 @@ pub use self::async_seek::AsyncSeek;
mod async_write;
pub use self::async_write::AsyncWrite;
+mod read_buf;
+pub use self::read_buf::ReadBuf;
+
// Re-export some types from `std::io` so that users don't have to deal
// with conflicts when `use`ing `tokio::io` and `std::io`.
pub use std::io::{Error, ErrorKind, Result, SeekFrom};
diff --git a/tokio/src/io/poll_evented.rs b/tokio/src/io/poll_evented.rs
index 5295bd71..785968f4 100644
--- a/tokio/src/io/poll_evented.rs
+++ b/tokio/src/io/poll_evented.rs
@@ -1,5 +1,5 @@
use crate::io::driver::platform;
-use crate::io::{AsyncRead, AsyncWrite, Registration};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf, Registration};
use mio::event::Evented;
use std::fmt;
@@ -384,18 +384,22 @@ where
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
ready!(self.poll_read_ready(cx, mio::Ready::readable()))?;
- let r = (*self).get_mut().read(buf);
+ // We can't assume the `Read` won't look at the read buffer,
+ // so we have to force initialization here.
+ let r = (*self).get_mut().read(buf.initialize_unfilled());
if is_wouldblock(&r) {
self.clear_read_ready(cx, mio::Ready::readable())?;
return Poll::Pending;
}
- Poll::Ready(r)
+ Poll::Ready(r.map(|n| {
+ buf.add_filled(n);
+ }))
}
}
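The `PollEvented` change illustrates the general recipe for wrapping a synchronous `std::io::Read` source: hand it `initialize_unfilled()` (an arbitrary `Read` impl may inspect the buffer it is given), then record progress with `add_filled`. A simplified sketch with a hypothetical adapter that ignores readiness entirely (a real adapter would also handle `WouldBlock`, as `PollEvented` does above):

```rust
use std::io::{self, Read};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};

// Hypothetical wrapper, for illustration only.
struct SyncReadAdapter<R>(R);

impl<R: Read + Unpin> AsyncRead for SyncReadAdapter<R> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Force initialization of the unfilled region before lending it out.
        let unfilled = buf.initialize_unfilled();
        match self.0.read(unfilled) {
            Ok(n) => {
                // Report how many of those initialized bytes were written.
                buf.add_filled(n);
                Poll::Ready(Ok(()))
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}
```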
diff --git a/tokio/src/io/read_buf.rs b/tokio/src/io/read_buf.rs
new file mode 100644
index 00000000..03b5d05c
--- /dev/null
+++ b/tokio/src/io/read_buf.rs
@@ -0,0 +1,253 @@
+// This lint claims ugly casting is somehow safer than transmute, but there's
+// no evidence that is the case. Shush.
+#![allow(clippy::transmute_ptr_to_ptr)]
+
+use std::fmt;
+use std::mem::{self, MaybeUninit};
+
+/// A wrapper around a byte buffer that is incrementally filled and initialized.
+///
+/// This type is a sort of "double cursor". It tracks three regions in the
+/// buffer: a region at the beginning of the buffer that has been logically
+/// filled with data, a region that has been initialized at some point but not
+/// yet logically filled, and a region at the end that is fully uninitialized.
+/// The filled region is guaranteed to be a subset of the initialized region.
+///
+/// In summary, the contents of the buffer can be visualized as:
+///
+/// ```not_rust
+/// [ capacity ]
+/// [ filled | unfilled ]
+/// [ initialized | uninitialized ]
+/// ```
+pub struct ReadBuf<'a> {
+ buf: &'a mut [MaybeUninit<u8>],
+ filled: usize,
+ initialized: usize,
+}
+
+impl<'a> ReadBuf<'a> {
+ /// Creates a new `ReadBuf` from a fully initialized buffer.
+ #[inline]
+ pub fn new(buf: &'a mut [u8]) -> ReadBuf<'a> {
+ let initialized = buf.len();
+ let buf = unsafe { mem::transmute::<&mut [u8], &mut [MaybeUninit<u8>]>(buf) };
+ ReadBuf {
+ buf,
+ filled: 0,
+ initialized,
+ }
+ }
+
+ /// Creates a new `ReadBuf` from a fully uninitialized buffer.
+ ///
+ /// Use `assume_init` if part of the buffer is known to be already initialized.
+ #[inline]
+ pub fn uninit(buf: &'a mut [MaybeUninit<u8>]) -> ReadBuf<'a> {
+ ReadBuf {
+ buf,
+ filled: 0,
+ initialized: 0,
+ }
+ }
+
+ /// Returns the total capacity of the buffer.
+ #[inline]
+ pub fn capacity(&self) -> usize {
+ self.buf.len()
+ }
+
+ /// Returns a shared reference to the filled portion of the buffer.
+ #[inline]
+ pub fn filled(&self) -> &[u8] {
+ let slice = &self.buf[..self.filled];
+ // safety: filled describes how far into the buffer that the
+ // user has filled with bytes, so it's been initialized.
+ // TODO: This could use `MaybeUninit::slice_get_ref` when it is stable.
+ unsafe { mem::transmute::<&[MaybeUninit<u8>], &[u8]>(slice) }
+ }
+
+ /// Returns a mutable reference to the filled portion of the buffer.
+ #[inline]
+ pub fn filled_mut(&mut self) -> &mut [u8] {
+ let slice = &mut self.buf[..self.filled];
+ // safety: filled describes how far into the buffer that the
+ // user has filled with bytes, so it's been initialized.
+ // TODO: This could use `MaybeUninit::slice_get_mut` when it is stable.
+ unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) }
+ }
+
+ /// Returns a shared reference to the initialized portion of the buffer.
+ ///
+ /// This includes the filled portion.
+ #[inline]
+ pub fn initialized(&self) -> &[u8] {
+ let slice = &self.buf[..self.initialized];
+ // safety: initialized describes how far into the buffer that the
+ // user has at some point initialized with bytes.
+