From 8dbc3c79379f2243fc04d444239d009c1c610016 Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Wed, 21 Oct 2020 14:08:49 -0700 Subject: io: add `AsyncReadExt::read_buf` (#3003) Brings back `read_buf` from 0.2. This will be stabilized as part of 1.0. --- tokio/Cargo.toml | 5 ++- tokio/src/io/util/async_read_ext.rs | 68 +++++++++++++++++++++++++++++++ tokio/src/io/util/async_write_ext.rs | 77 ++++++++++++++++++++++++++++++++++++ tokio/src/io/util/mod.rs | 2 + tokio/src/io/util/read_buf.rs | 72 +++++++++++++++++++++++++++++++++ tokio/src/io/util/read_to_end.rs | 3 +- tokio/src/io/util/write_buf.rs | 55 ++++++++++++++++++++++++++ tokio/tests/io_read.rs | 21 ++++++++++ tokio/tests/io_read_buf.rs | 36 +++++++++++++++++ tokio/tests/io_write_buf.rs | 56 ++++++++++++++++++++++++++ 10 files changed, 392 insertions(+), 3 deletions(-) create mode 100644 tokio/src/io/util/read_buf.rs create mode 100644 tokio/src/io/util/write_buf.rs create mode 100644 tokio/tests/io_read_buf.rs create mode 100644 tokio/tests/io_write_buf.rs diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index e19b8c91..0ebb08d2 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -45,7 +45,7 @@ full = [ ] fs = [] -io-util = ["memchr"] +io-util = ["memchr", "bytes"] # stdin, stdout, stderr io-std = [] macros = ["tokio-macros"] @@ -58,6 +58,7 @@ net = [ "mio/uds", ] process = [ + "bytes", "lazy_static", "libc", "mio/os-poll", @@ -88,10 +89,10 @@ time = [] [dependencies] tokio-macros = { version = "0.3.0", path = "../tokio-macros", optional = true } -bytes = "0.5.0" pin-project-lite = "0.1.1" # Everything else is optional... +bytes = { version = "0.6.0", optional = true } fnv = { version = "1.0.6", optional = true } futures-core = { version = "0.3.0", optional = true } lazy_static = { version = "1.0.2", optional = true } diff --git a/tokio/src/io/util/async_read_ext.rs b/tokio/src/io/util/async_read_ext.rs index d631bd7e..0ab66c28 100644 --- a/tokio/src/io/util/async_read_ext.rs +++ b/tokio/src/io/util/async_read_ext.rs @@ -1,5 +1,6 @@ use crate::io::util::chain::{chain, Chain}; use crate::io::util::read::{read, Read}; +use crate::io::util::read_buf::{read_buf, ReadBuf}; use crate::io::util::read_exact::{read_exact, ReadExact}; use crate::io::util::read_int::{ ReadI128, ReadI128Le, ReadI16, ReadI16Le, ReadI32, ReadI32Le, ReadI64, ReadI64Le, ReadI8, @@ -12,6 +13,8 @@ use crate::io::util::read_to_string::{read_to_string, ReadToString}; use crate::io::util::take::{take, Take}; use crate::io::AsyncRead; +use bytes::BufMut; + cfg_io_util! { /// Defines numeric reader macro_rules! read_impl { @@ -163,6 +166,71 @@ cfg_io_util! { read(self, buf) } + /// Pulls some bytes from this source into the specified buffer, + /// advancing the buffer's internal cursor. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_buf(&mut self, buf: &mut B) -> io::Result; + /// ``` + /// + /// Usually, only a single `read` syscall is issued, even if there is + /// more space in the supplied buffer. + /// + /// This function does not provide any guarantees about whether it + /// completes immediately or asynchronously + /// + /// # Return + /// + /// On a successful read, the number of read bytes is returned. If the + /// supplied buffer is not empty and the function returns `Ok(0)` then + /// the source as reached an "end-of-file" event. + /// + /// # Errors + /// + /// If this function encounters any form of I/O or other error, an error + /// variant will be returned. If an error is returned then it must be + /// guaranteed that no bytes were read. + /// + /// # Examples + /// + /// [`File`] implements `Read` and [`BytesMut`] implements [`BufMut`]: + /// + /// [`File`]: crate::fs::File + /// [`BytesMut`]: bytes::BytesMut + /// [`BufMut`]: bytes::BufMut + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use bytes::BytesMut; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut f = File::open("foo.txt").await?; + /// let mut buffer = BytesMut::with_capacity(10); + /// + /// assert!(buffer.is_empty()); + /// + /// // read up to 10 bytes, note that the return value is not needed + /// // to access the data that was read as `buffer`'s internal + /// // cursor is updated. + /// f.read_buf(&mut buffer).await?; + /// + /// println!("The bytes: {:?}", &buffer[..]); + /// Ok(()) + /// } + /// ``` + fn read_buf<'a, B>(&'a mut self, buf: &'a mut B) -> ReadBuf<'a, Self, B> + where + Self: Sized + Unpin, + B: BufMut, + { + read_buf(self, buf) + } + /// Reads the exact number of bytes required to fill `buf`. /// /// Equivalent to: diff --git a/tokio/src/io/util/async_write_ext.rs b/tokio/src/io/util/async_write_ext.rs index 5c6187b7..e6ef5b20 100644 --- a/tokio/src/io/util/async_write_ext.rs +++ b/tokio/src/io/util/async_write_ext.rs @@ -2,6 +2,7 @@ use crate::io::util::flush::{flush, Flush}; use crate::io::util::shutdown::{shutdown, Shutdown}; use crate::io::util::write::{write, Write}; use crate::io::util::write_all::{write_all, WriteAll}; +use crate::io::util::write_buf::{write_buf, WriteBuf}; use crate::io::util::write_int::{ WriteI128, WriteI128Le, WriteI16, WriteI16Le, WriteI32, WriteI32Le, WriteI64, WriteI64Le, WriteI8, @@ -12,6 +13,8 @@ use crate::io::util::write_int::{ }; use crate::io::AsyncWrite; +use bytes::Buf; + cfg_io_util! { /// Defines numeric writer macro_rules! write_impl { @@ -116,6 +119,80 @@ cfg_io_util! { write(self, src) } + + /// Writes a buffer into this writer, advancing the buffer's internal + /// cursor. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_buf(&mut self, buf: &mut B) -> io::Result; + /// ``` + /// + /// This function will attempt to write the entire contents of `buf`, but + /// the entire write may not succeed, or the write may also generate an + /// error. After the operation completes, the buffer's + /// internal cursor is advanced by the number of bytes written. A + /// subsequent call to `write_buf` using the **same** `buf` value will + /// resume from the point that the first call to `write_buf` completed. + /// A call to `write_buf` represents *at most one* attempt to write to any + /// wrapped object. + /// + /// # Return + /// + /// If the return value is `Ok(n)` then it must be guaranteed that `n <= + /// buf.len()`. A return value of `0` typically means that the + /// underlying object is no longer able to accept bytes and will likely + /// not be able to in the future as well, or that the buffer provided is + /// empty. + /// + /// # Errors + /// + /// Each call to `write` may generate an I/O error indicating that the + /// operation could not be completed. If an error is returned then no bytes + /// in the buffer were written to this writer. + /// + /// It is **not** considered an error if the entire buffer could not be + /// written to this writer. + /// + /// # Examples + /// + /// [`File`] implements `Read` and [`Cursor<&[u8]>`] implements [`Buf`]: + /// + /// [`File`]: crate::fs::File + /// [`Buf`]: bytes::Buf + /// + /// ```no_run + /// use tokio::io::{self, AsyncWriteExt}; + /// use tokio::fs::File; + /// + /// use bytes::Buf; + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// let mut buffer = Cursor::new(b"data to write"); + /// + /// // Loop until the entire contents of the buffer are written to + /// // the file. + /// while buffer.has_remaining() { + /// // Writes some prefix of the byte string, not necessarily + /// // all of it. + /// file.write_buf(&mut buffer).await?; + /// } + /// + /// Ok(()) + /// } + /// ``` + fn write_buf<'a, B>(&'a mut self, src: &'a mut B) -> WriteBuf<'a, Self, B> + where + Self: Sized + Unpin, + B: Buf, + { + write_buf(self, src) + } + /// Attempts to write an entire buffer into this writer. /// /// Equivalent to: diff --git a/tokio/src/io/util/mod.rs b/tokio/src/io/util/mod.rs index c945be0d..e75ea034 100644 --- a/tokio/src/io/util/mod.rs +++ b/tokio/src/io/util/mod.rs @@ -42,6 +42,7 @@ cfg_io_util! { pub use mem::{duplex, DuplexStream}; mod read; + mod read_buf; mod read_exact; mod read_int; mod read_line; @@ -70,6 +71,7 @@ cfg_io_util! { mod write; mod write_all; + mod write_buf; mod write_int; diff --git a/tokio/src/io/util/read_buf.rs b/tokio/src/io/util/read_buf.rs new file mode 100644 index 00000000..696deefd --- /dev/null +++ b/tokio/src/io/util/read_buf.rs @@ -0,0 +1,72 @@ +use crate::io::AsyncRead; + +use bytes::BufMut; +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub(crate) fn read_buf<'a, R, B>(reader: &'a mut R, buf: &'a mut B) -> ReadBuf<'a, R, B> +where + R: AsyncRead + Unpin, + B: BufMut, +{ + ReadBuf { + reader, + buf, + _pin: PhantomPinned, + } +} + +pin_project! { + /// Future returned by [`read_buf`](crate::io::AsyncReadExt::read_buf). + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct ReadBuf<'a, R, B> { + reader: &'a mut R, + buf: &'a mut B, + #[pin] + _pin: PhantomPinned, + } +} + +impl Future for ReadBuf<'_, R, B> +where + R: AsyncRead + Unpin, + B: BufMut, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use crate::io::ReadBuf; + use std::mem::MaybeUninit; + + let me = self.project(); + + if !me.buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let dst = me.buf.bytes_mut(); + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(Pin::new(me.reader).poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + me.buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) + } +} diff --git a/tokio/src/io/util/read_to_end.rs b/tokio/src/io/util/read_to_end.rs index f4fbe631..a9746259 100644 --- a/tokio/src/io/util/read_to_end.rs +++ b/tokio/src/io/util/read_to_end.rs @@ -98,7 +98,8 @@ fn reserve(buf: &mut Vec, bytes: usize) { /// Returns the unused capacity of the provided vector. fn get_unused_capacity(buf: &mut Vec) -> &mut [MaybeUninit] { - bytes::BufMut::bytes_mut(buf) + let uninit = bytes::BufMut::bytes_mut(buf); + unsafe { &mut *(uninit as *mut _ as *mut [MaybeUninit]) } } impl Future for ReadToEnd<'_, A> diff --git a/tokio/src/io/util/write_buf.rs b/tokio/src/io/util/write_buf.rs new file mode 100644 index 00000000..1310e5c1 --- /dev/null +++ b/tokio/src/io/util/write_buf.rs @@ -0,0 +1,55 @@ +use crate::io::AsyncWrite; + +use bytes::Buf; +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A future to write some of the buffer to an `AsyncWrite`. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct WriteBuf<'a, W, B> { + writer: &'a mut W, + buf: &'a mut B, + #[pin] + _pin: PhantomPinned, + } +} + +/// Tries to write some bytes from the given `buf` to the writer in an +/// asynchronous manner, returning a future. +pub(crate) fn write_buf<'a, W, B>(writer: &'a mut W, buf: &'a mut B) -> WriteBuf<'a, W, B> +where + W: AsyncWrite + Unpin, + B: Buf, +{ + WriteBuf { + writer, + buf, + _pin: PhantomPinned, + } +} + +impl Future for WriteBuf<'_, W, B> +where + W: AsyncWrite + Unpin, + B: Buf, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + + if !me.buf.has_remaining() { + return Poll::Ready(Ok(0)); + } + + let n = ready!(Pin::new(me.writer).poll_write(cx, me.buf.bytes()))?; + me.buf.advance(n); + Poll::Ready(Ok(n)) + } +} diff --git a/tokio/tests/io_read.rs b/tokio/tests/io_read.rs index 29d7d6d7..cb1aa705 100644 --- a/tokio/tests/io_read.rs +++ b/tokio/tests/io_read.rs @@ -36,3 +36,24 @@ async fn read() { assert_eq!(n, 11); assert_eq!(buf[..], b"hello world"[..]); } + +struct BadAsyncRead; + +impl AsyncRead for BadAsyncRead { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + *buf = ReadBuf::new(Box::leak(vec![0; buf.capacity()].into_boxed_slice())); + buf.advance(buf.capacity()); + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +#[should_panic] +async fn read_buf_bad_async_read() { + let mut buf = Vec::with_capacity(10); + BadAsyncRead.read_buf(&mut buf).await.unwrap(); +} diff --git a/tokio/tests/io_read_buf.rs b/tokio/tests/io_read_buf.rs new file mode 100644 index 00000000..0328168d --- /dev/null +++ b/tokio/tests/io_read_buf.rs @@ -0,0 +1,36 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; +use tokio_test::assert_ok; + +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[tokio::test] +async fn read_buf() { + struct Rd { + cnt: usize, + } + + impl AsyncRead for Rd { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.cnt += 1; + buf.put_slice(b"hello world"); + Poll::Ready(Ok(())) + } + } + + let mut buf = vec![]; + let mut rd = Rd { cnt: 0 }; + + let n = assert_ok!(rd.read_buf(&mut buf).await); + assert_eq!(1, rd.cnt); + assert_eq!(n, 11); + assert_eq!(buf[..], b"hello world"[..]); +} diff --git a/tokio/tests/io_write_buf.rs b/tokio/tests/io_write_buf.rs new file mode 100644 index 00000000..9ae655b6 --- /dev/null +++ b/tokio/tests/io_write_buf.rs @@ -0,0 +1,56 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio_test::assert_ok; + +use bytes::BytesMut; +use std::cmp; +use std::io::{self, Cursor}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[tokio::test] +async fn write_all() { + struct Wr { + buf: BytesMut, + cnt: usize, + } + + impl AsyncWrite for Wr { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + assert_eq!(self.cnt, 0); + + let n = cmp::min(4, buf.len()); + let buf = &buf[0..n]; + + self.cnt += 1; + self.buf.extend(buf); + Ok(buf.len()).into() + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() + } + } + + let mut wr = Wr { + buf: BytesMut::with_capacity(64), + cnt: 0, + }; + + let mut buf = Cursor::new(&b"hello world"[..]); + + assert_ok!(wr.write_buf(&mut buf).await); + assert_eq!(wr.buf, b"hell"[..]); + assert_eq!(wr.cnt, 1); + assert_eq!(buf.position(), 4); +} -- cgit v1.2.3