Diffstat (limited to 'tokio/src/io/async_read.rs')
-rw-r--r--  tokio/src/io/async_read.rs  125
1 file changed, 40 insertions, 85 deletions
diff --git a/tokio/src/io/async_read.rs b/tokio/src/io/async_read.rs
index 1aef4150..d341b63d 100644
--- a/tokio/src/io/async_read.rs
+++ b/tokio/src/io/async_read.rs
@@ -1,6 +1,6 @@
+use super::ReadBuf;
use bytes::BufMut;
use std::io;
-use std::mem::MaybeUninit;
use std::ops::DerefMut;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -41,47 +41,6 @@ use std::task::{Context, Poll};
/// [`Read::read`]: std::io::Read::read
/// [`AsyncReadExt`]: crate::io::AsyncReadExt
pub trait AsyncRead {
- /// Prepares an uninitialized buffer to be safe to pass to `read`. Returns
- /// `true` if the supplied buffer was zeroed out.
- ///
- /// While it would be highly unusual, implementations of [`io::Read`] are
- /// able to read data from the buffer passed as an argument. Because of
- /// this, the buffer passed to [`io::Read`] must be initialized memory. In
- /// situations where large numbers of buffers are used, constantly having to
- /// zero out buffers can be expensive.
- ///
- /// This function does any necessary work to prepare an uninitialized buffer
- /// to be safe to pass to `read`. If `read` guarantees to never attempt to
- /// read data out of the supplied buffer, then `prepare_uninitialized_buffer`
- /// doesn't need to do any work.
- ///
- /// If this function returns `true`, then the memory has been zeroed out.
- /// This allows implementations of `AsyncRead` which are composed of
- /// multiple subimplementations to efficiently implement
- /// `prepare_uninitialized_buffer`.
- ///
- /// This function isn't actually `unsafe` to call but `unsafe` to implement.
- /// The implementer must ensure that either the whole `buf` has been zeroed
- /// or `poll_read_buf()` overwrites the buffer without reading it and returns
- /// correct value.
- ///
- /// This function is called from [`poll_read_buf`].
- ///
- /// # Safety
- ///
- /// Implementations that return `false` must never read from data slices
- /// that they did not write to.
- ///
- /// [`io::Read`]: std::io::Read
- /// [`poll_read_buf`]: method@Self::poll_read_buf
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- for x in buf {
- *x = MaybeUninit::new(0);
- }
-
- true
- }
-
/// Attempts to read from the `AsyncRead` into `buf`.
///
/// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
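For context, the method removed above existed so that an implementation which never reads from the destination buffer could opt out of zeroing it. A rough sketch of that old pattern, against the pre-change signatures, with a made-up `PatternSource` type:

    use std::io;
    use std::mem::MaybeUninit;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::AsyncRead;

    // Hypothetical reader that only ever writes into `buf`, so it may return
    // `false` here and skip zeroing (the contract the removed docs describe).
    struct PatternSource(u8);

    impl AsyncRead for PatternSource {
        unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
            false
        }

        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // Write-only access to `buf`, then report how many bytes were produced.
            for b in buf.iter_mut() {
                *b = self.0;
            }
            Poll::Ready(Ok(buf.len()))
        }
    }

With `ReadBuf`, the same guarantee falls out of the API: the safe `ReadBuf` methods only let an implementation append initialized data, so the `unsafe` opt-out is no longer needed.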
@@ -93,8 +52,8 @@ pub trait AsyncRead {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>>;
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>>;
/// Pulls some bytes from this source into the specified `BufMut`, returning
/// how many bytes were read.
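Under the new signature the number of bytes read is reported through the `ReadBuf` itself rather than as a `usize` return value. A minimal sketch of an implementor, using only the `remaining`/`append` methods that appear in this diff (the `Zeroes` type is hypothetical):

    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{AsyncRead, ReadBuf};

    // Hypothetical reader that yields an endless stream of zero bytes.
    struct Zeroes;

    impl AsyncRead for Zeroes {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let zeros = [0u8; 64];
            // `remaining()` is the unfilled capacity; `append` copies bytes in and
            // advances the filled region, as the slice and cursor impls below do.
            let n = buf.remaining().min(zeros.len());
            buf.append(&zeros[..n]);
            Poll::Ready(Ok(()))
        }
    }

A caller recovers the count with `buf.filled().len()`, which is exactly what the rewritten `poll_read_buf` below does.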
@@ -114,37 +73,26 @@ pub trait AsyncRead {
return Poll::Ready(Ok(0));
}
- unsafe {
- let n = {
- let b = buf.bytes_mut();
-
- self.prepare_uninitialized_buffer(b);
-
- // Convert to `&mut [u8]`
- let b = &mut *(b as *mut [MaybeUninit<u8>] as *mut [u8]);
+ let mut b = ReadBuf::uninit(buf.bytes_mut());
- let n = ready!(self.poll_read(cx, b))?;
- assert!(n <= b.len(), "Bad AsyncRead implementation, more bytes were reported as read than the buffer can hold");
- n
- };
+ ready!(self.poll_read(cx, &mut b))?;
+ let n = b.filled().len();
+ // Safety: we can assume `n` bytes were read, since they are in `filled`.
+ unsafe {
buf.advance_mut(n);
- Poll::Ready(Ok(n))
}
+ Poll::Ready(Ok(n))
}
}
macro_rules! deref_async_read {
() => {
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- (**self).prepare_uninitialized_buffer(buf)
- }
-
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
Pin::new(&mut **self).poll_read(cx, buf)
}
};
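The rewritten `poll_read_buf` above is what the buffered read helpers bottom out in. A small caller-side sketch with `bytes::BytesMut`, assuming the `AsyncReadExt::read_buf` helper, which is outside this diff:

    use bytes::BytesMut;
    use tokio::io::{AsyncRead, AsyncReadExt};

    // Read once into the spare capacity of a `BytesMut`. `poll_read_buf`
    // advances the buffer by however many bytes ended up in `filled`.
    async fn fill_once<R: AsyncRead + Unpin>(src: &mut R) -> std::io::Result<BytesMut> {
        let mut buf = BytesMut::with_capacity(4096);
        let n = src.read_buf(&mut buf).await?;
        // The buffer grew by `n`; none of the spare capacity had to be zeroed first.
        debug_assert_eq!(buf.len(), n);
        Ok(buf)
    }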
@@ -163,43 +111,50 @@ where
P: DerefMut + Unpin,
P::Target: AsyncRead,
{
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- (**self).prepare_uninitialized_buffer(buf)
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.get_mut().as_mut().poll_read(cx, buf)
}
}
impl AsyncRead for &[u8] {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
- self: Pin<&mut Self>,
+ mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- Poll::Ready(io::Read::read(self.get_mut(), buf))
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ let amt = std::cmp::min(self.len(), buf.remaining());
+ let (a, b) = self.split_at(amt);
+ buf.append(a);
+ *self = b;
+ Poll::Ready(Ok(()))
}
}
impl<T: AsRef<[u8]> + Unpin> AsyncRead for io::Cursor<T> {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
- self: Pin<&mut Self>,
+ mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- Poll::Ready(io::Read::read(self.get_mut(), buf))
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ let pos = self.position();
+ let slice: &[u8] = (*self).get_ref().as_ref();
+
+ // The position could technically be out of bounds, so don't panic...
+ if pos > slice.len() as u64 {
+ return Poll::Ready(Ok(()));
+ }
+
+ let start = pos as usize;
+ let amt = std::cmp::min(slice.len() - start, buf.remaining());
+ // Add won't overflow because of pos check above.
+ let end = start + amt;
+ buf.append(&slice[start..end]);
+ self.set_position(end as u64);
+
+ Poll::Ready(Ok(()))
}
}
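To exercise the new slice impl directly, the trait can be polled by hand. A sketch that assumes a `ReadBuf::new` constructor over an initialized array (this diff only shows `uninit`) and borrows `futures::task::noop_waker` for the dummy context:

    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{AsyncRead, ReadBuf};

    fn main() {
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        // Destination smaller than the source, so the read is capped at `remaining()`.
        let mut storage = [0u8; 4];
        let mut buf = ReadBuf::new(&mut storage);

        let mut src: &[u8] = b"hello";
        match Pin::new(&mut src).poll_read(&mut cx, &mut buf) {
            Poll::Ready(Ok(())) => {
                assert_eq!(buf.filled(), b"hell"); // the appended bytes
                assert_eq!(src, b"o");             // the slice advanced past what was read
            }
            _ => unreachable!("reading from a slice never returns Pending or an error"),
        }
    }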