From 37f405bd3b06921598d298b0ba5b9296656454bf Mon Sep 17 00:00:00 2001 From: Alice Ryhl Date: Tue, 8 Sep 2020 09:12:32 +0200 Subject: io: move StreamReader and ReaderStream into tokio_util (#2788) Co-authored-by: Mikail Bagishov Co-authored-by: Eliza Weisman --- tokio-util/Cargo.toml | 3 +- tokio-util/src/cfg.rs | 10 ++ tokio-util/src/io/mod.rs | 13 ++ tokio-util/src/io/reader_stream.rs | 100 +++++++++++++++ tokio-util/src/io/stream_reader.rs | 181 ++++++++++++++++++++++++++++ tokio-util/src/lib.rs | 4 + tokio-util/tests/io_reader_stream.rs | 65 ++++++++++ tokio-util/tests/io_stream_reader.rs | 35 ++++++ tokio-util/tests/sync_cancellation_token.rs | 2 + tokio-util/tests/udp.rs | 2 + 10 files changed, 414 insertions(+), 1 deletion(-) create mode 100644 tokio-util/src/io/mod.rs create mode 100644 tokio-util/src/io/reader_stream.rs create mode 100644 tokio-util/src/io/stream_reader.rs create mode 100644 tokio-util/tests/io_reader_stream.rs create mode 100644 tokio-util/tests/io_stream_reader.rs (limited to 'tokio-util') diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml index b47c9dfc..85b4e592 100644 --- a/tokio-util/Cargo.toml +++ b/tokio-util/Cargo.toml @@ -25,11 +25,12 @@ publish = false default = [] # Shorthand for enabling everything -full = ["codec", "udp", "compat"] +full = ["codec", "udp", "compat", "io"] compat = ["futures-io",] codec = ["tokio/stream"] udp = ["tokio/udp"] +io = [] [dependencies] tokio = { version = "0.3.0", path = "../tokio" } diff --git a/tokio-util/src/cfg.rs b/tokio-util/src/cfg.rs index 27e8c66a..2efa5f09 100644 --- a/tokio-util/src/cfg.rs +++ b/tokio-util/src/cfg.rs @@ -27,3 +27,13 @@ macro_rules! cfg_udp { )* } } + +macro_rules! cfg_io { + ($($item:item)*) => { + $( + #[cfg(feature = "io")] + #[cfg_attr(docsrs, doc(cfg(feature = "io")))] + $item + )* + } +} diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs new file mode 100644 index 00000000..53066c4e --- /dev/null +++ b/tokio-util/src/io/mod.rs @@ -0,0 +1,13 @@ +//! Helpers for IO related tasks. +//! +//! These types are often used in combination with hyper or reqwest, as they +//! allow converting between a hyper [`Body`] and [`AsyncRead`]. +//! +//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html +//! [`AsyncRead`]: tokio::io::AsyncRead + +mod reader_stream; +mod stream_reader; + +pub use self::reader_stream::ReaderStream; +pub use self::stream_reader::StreamReader; diff --git a/tokio-util/src/io/reader_stream.rs b/tokio-util/src/io/reader_stream.rs new file mode 100644 index 00000000..bde7ccee --- /dev/null +++ b/tokio-util/src/io/reader_stream.rs @@ -0,0 +1,100 @@ +use bytes::{Bytes, BytesMut}; +use futures_core::stream::Stream; +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; + +const CAPACITY: usize = 4096; + +pin_project! { + /// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks. + /// + /// This stream is fused. It performs the inverse operation of + /// [`StreamReader`]. + /// + /// # Example + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// use tokio::stream::StreamExt; + /// use tokio_util::io::ReaderStream; + /// + /// // Create a stream of data. + /// let data = b"hello, world!"; + /// let mut stream = ReaderStream::new(&data[..]); + /// + /// // Read all of the chunks into a vector. + /// let mut stream_contents = Vec::new(); + /// while let Some(chunk) = stream.next().await { + /// stream_contents.extend_from_slice(&chunk?); + /// } + /// + /// // Once the chunks are concatenated, we should have the + /// // original data. + /// assert_eq!(stream_contents, data); + /// # Ok(()) + /// # } + /// ``` + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`StreamReader`]: crate::io::StreamReader + /// [`Stream`]: tokio::stream::Stream + #[derive(Debug)] + pub struct ReaderStream { + // Reader itself. + // + // This value is `None` if the stream has terminated. + #[pin] + reader: Option, + // Working buffer, used to optimize allocations. + buf: BytesMut, + } +} + +impl ReaderStream { + /// Convert an [`AsyncRead`] into a [`Stream`] with item type + /// `Result`. + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Stream`]: tokio::stream::Stream + pub fn new(reader: R) -> Self { + ReaderStream { + reader: Some(reader), + buf: BytesMut::new(), + } + } +} + +impl Stream for ReaderStream { + type Item = std::io::Result; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.as_mut().project(); + + let reader = match this.reader.as_pin_mut() { + Some(r) => r, + None => return Poll::Ready(None), + }; + + if this.buf.capacity() == 0 { + this.buf.reserve(CAPACITY); + } + + match reader.poll_read_buf(cx, &mut this.buf) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { + self.project().reader.set(None); + Poll::Ready(Some(Err(err))) + } + Poll::Ready(Ok(0)) => { + self.project().reader.set(None); + Poll::Ready(None) + } + Poll::Ready(Ok(_)) => { + let chunk = this.buf.split(); + Poll::Ready(Some(Ok(chunk.freeze()))) + } + } + } +} diff --git a/tokio-util/src/io/stream_reader.rs b/tokio-util/src/io/stream_reader.rs new file mode 100644 index 00000000..5c3ab019 --- /dev/null +++ b/tokio-util/src/io/stream_reader.rs @@ -0,0 +1,181 @@ +use bytes::{Buf, BufMut}; +use futures_core::stream::Stream; +use pin_project_lite::pin_project; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; + +pin_project! { + /// Convert a [`Stream`] of byte chunks into an [`AsyncRead`]. + /// + /// This type performs the inverse operation of [`ReaderStream`]. + /// + /// # Example + /// + /// ``` + /// use bytes::Bytes; + /// use tokio::io::{AsyncReadExt, Result}; + /// use tokio_util::io::StreamReader; + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// + /// // Create a stream from an iterator. + /// let stream = tokio::stream::iter(vec![ + /// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])), + /// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])), + /// Result::Ok(Bytes::from_static(&[8, 9, 10, 11])), + /// ]); + /// + /// // Convert it to an AsyncRead. + /// let mut read = StreamReader::new(stream); + /// + /// // Read five bytes from the stream. + /// let mut buf = [0; 5]; + /// read.read_exact(&mut buf).await?; + /// assert_eq!(buf, [0, 1, 2, 3, 4]); + /// + /// // Read the rest of the current chunk. + /// assert_eq!(read.read(&mut buf).await?, 3); + /// assert_eq!(&buf[..3], [5, 6, 7]); + /// + /// // Read the next chunk. + /// assert_eq!(read.read(&mut buf).await?, 4); + /// assert_eq!(&buf[..4], [8, 9, 10, 11]); + /// + /// // We have now reached the end. + /// assert_eq!(read.read(&mut buf).await?, 0); + /// + /// # Ok(()) + /// # } + /// ``` + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Stream`]: tokio::stream::Stream + /// [`ReaderStream`]: crate::io::ReaderStream + #[derive(Debug)] + pub struct StreamReader { + #[pin] + inner: S, + chunk: Option, + } +} + +impl StreamReader +where + S: Stream>, + B: Buf, + E: Into, +{ + /// Convert a stream of byte chunks into an [`AsyncRead`](tokio::io::AsyncRead). + /// + /// The item should be a [`Result`] with the ok variant being something that + /// implements the [`Buf`] trait (e.g. `Vec` or `Bytes`). The error + /// should be convertible into an [io error]. + /// + /// [`Result`]: std::result::Result + /// [`Buf`]: bytes::Buf + /// [io error]: std::io::Error + pub fn new(stream: S) -> Self { + Self { + inner: stream, + chunk: None, + } + } + + /// Do we have a chunk and is it non-empty? + fn has_chunk(self: Pin<&mut Self>) -> bool { + if let Some(chunk) = self.project().chunk { + chunk.remaining() > 0 + } else { + false + } + } +} + +impl AsyncRead for StreamReader +where + S: Stream>, + B: Buf, + E: Into, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + + let inner_buf = match self.as_mut().poll_fill_buf(cx) { + Poll::Ready(Ok(buf)) => buf, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + }; + let len = std::cmp::min(inner_buf.len(), buf.remaining()); + buf.append(&inner_buf[..len]); + + self.consume(len); + Poll::Ready(Ok(())) + } + fn poll_read_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut BM, + ) -> Poll> + where + Self: Sized, + { + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let inner_buf = match self.as_mut().poll_fill_buf(cx) { + Poll::Ready(Ok(buf)) => buf, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + }; + let len = std::cmp::min(inner_buf.len(), buf.remaining_mut()); + buf.put_slice(&inner_buf[..len]); + + self.consume(len); + Poll::Ready(Ok(len)) + } +} + +impl AsyncBufRead for StreamReader +where + S: Stream>, + B: Buf, + E: Into, +{ + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if self.as_mut().has_chunk() { + // This unwrap is very sad, but it can't be avoided. + let buf = self.project().chunk.as_ref().unwrap().bytes(); + return Poll::Ready(Ok(buf)); + } else { + match self.as_mut().project().inner.poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + // Go around the loop in case the chunk is empty. + *self.as_mut().project().chunk = Some(chunk); + } + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), + Poll::Ready(None) => return Poll::Ready(Ok(&[])), + Poll::Pending => return Poll::Pending, + } + } + } + } + fn consume(self: Pin<&mut Self>, amt: usize) { + if amt > 0 { + self.project() + .chunk + .as_mut() + .expect("No chunk present") + .advance(amt); + } + } +} diff --git a/tokio-util/src/lib.rs b/tokio-util/src/lib.rs index 3e9a3b7e..49733c6a 100644 --- a/tokio-util/src/lib.rs +++ b/tokio-util/src/lib.rs @@ -38,6 +38,10 @@ cfg_compat! { pub mod compat; } +cfg_io! { + pub mod io; +} + pub mod context; pub mod sync; diff --git a/tokio-util/tests/io_reader_stream.rs b/tokio-util/tests/io_reader_stream.rs new file mode 100644 index 00000000..b906de09 --- /dev/null +++ b/tokio-util/tests/io_reader_stream.rs @@ -0,0 +1,65 @@ +#![warn(rust_2018_idioms)] + +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, ReadBuf}; +use tokio::stream::StreamExt; + +/// produces at most `remaining` zeros, that returns error. +/// each time it reads at most 31 byte. +struct Reader { + remaining: usize, +} + +impl AsyncRead for Reader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = Pin::into_inner(self); + assert_ne!(buf.remaining(), 0); + if this.remaining > 0 { + let n = std::cmp::min(this.remaining, buf.remaining()); + let n = std::cmp::min(n, 31); + for x in &mut buf.initialize_unfilled_to(n)[..n] { + *x = 0; + } + buf.add_filled(n); + this.remaining -= n; + Poll::Ready(Ok(())) + } else { + Poll::Ready(Err(std::io::Error::from_raw_os_error(22))) + } + } +} + +#[tokio::test] +async fn correct_behavior_on_errors() { + let reader = Reader { remaining: 8000 }; + let mut stream = tokio_util::io::ReaderStream::new(reader); + let mut zeros_received = 0; + let mut had_error = false; + loop { + let item = stream.next().await.unwrap(); + println!("{:?}", item); + match item { + Ok(bytes) => { + let bytes = &*bytes; + for byte in bytes { + assert_eq!(*byte, 0); + zeros_received += 1; + } + } + Err(_) => { + assert!(!had_error); + had_error = true; + break; + } + } + } + + assert!(had_error); + assert_eq!(zeros_received, 8000); + assert!(stream.next().await.is_none()); +} diff --git a/tokio-util/tests/io_stream_reader.rs b/tokio-util/tests/io_stream_reader.rs new file mode 100644 index 00000000..b0ed1d2d --- /dev/null +++ b/tokio-util/tests/io_stream_reader.rs @@ -0,0 +1,35 @@ +#![warn(rust_2018_idioms)] + +use bytes::Bytes; +use tokio::io::AsyncReadExt; +use tokio::stream::iter; +use tokio_util::io::StreamReader; + +#[tokio::test] +async fn test_stream_reader() -> std::io::Result<()> { + let stream = iter(vec![ + std::io::Result::Ok(Bytes::from_static(&[])), + Ok(Bytes::from_static(&[0, 1, 2, 3])), + Ok(Bytes::from_static(&[])), + Ok(Bytes::from_static(&[4, 5, 6, 7])), + Ok(Bytes::from_static(&[])), + Ok(Bytes::from_static(&[8, 9, 10, 11])), + Ok(Bytes::from_static(&[])), + ]); + + let mut read = StreamReader::new(stream); + + let mut buf = [0; 5]; + read.read_exact(&mut buf).await?; + assert_eq!(buf, [0, 1, 2, 3, 4]); + + assert_eq!(read.read(&mut buf).await?, 3); + assert_eq!(&buf[..3], [5, 6, 7]); + + assert_eq!(read.read(&mut buf).await?, 4); + assert_eq!(&buf[..4], [8, 9, 10, 11]); + + assert_eq!(read.read(&mut buf).await?, 0); + + Ok(()) +} diff --git a/tokio-util/tests/sync_cancellation_token.rs b/tokio-util/tests/sync_cancellation_token.rs index c65a6425..438e5d5e 100644 --- a/tokio-util/tests/sync_cancellation_token.rs +++ b/tokio-util/tests/sync_cancellation_token.rs @@ -1,3 +1,5 @@ +#![warn(rust_2018_idioms)] + use tokio::pin; use tokio_util::sync::CancellationToken; diff --git a/tokio-util/tests/udp.rs b/tokio-util/tests/udp.rs index d0320beb..4820ac72 100644 --- a/tokio-util/tests/udp.rs +++ b/tokio-util/tests/udp.rs @@ -1,3 +1,5 @@ +#![warn(rust_2018_idioms)] + use tokio::{net::UdpSocket, stream::StreamExt}; use tokio_util::codec::{Decoder, Encoder, LinesCodec}; use tokio_util::udp::UdpFramed; -- cgit v1.2.3