summaryrefslogtreecommitdiffstats
path: root/tokio-util
diff options
context:
space:
mode:
authorAlice Ryhl <alice@ryhl.io>2020-09-08 09:12:32 +0200
committerGitHub <noreply@github.com>2020-09-08 09:12:32 +0200
commit37f405bd3b06921598d298b0ba5b9296656454bf (patch)
tree3098806c15ddae632e5f02706828d060608fea6c /tokio-util
parent7c254eca446e56bbc41cbc309c2588f2d241f46a (diff)
io: move StreamReader and ReaderStream into tokio_util (#2788)
Co-authored-by: Mikail Bagishov <bagishov.mikail@yandex.ru> Co-authored-by: Eliza Weisman <eliza@buoyant.io>
Diffstat (limited to 'tokio-util')
-rw-r--r--tokio-util/Cargo.toml3
-rw-r--r--tokio-util/src/cfg.rs10
-rw-r--r--tokio-util/src/io/mod.rs13
-rw-r--r--tokio-util/src/io/reader_stream.rs100
-rw-r--r--tokio-util/src/io/stream_reader.rs181
-rw-r--r--tokio-util/src/lib.rs4
-rw-r--r--tokio-util/tests/io_reader_stream.rs65
-rw-r--r--tokio-util/tests/io_stream_reader.rs35
-rw-r--r--tokio-util/tests/sync_cancellation_token.rs2
-rw-r--r--tokio-util/tests/udp.rs2
10 files changed, 414 insertions, 1 deletions
diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml
index b47c9dfc..85b4e592 100644
--- a/tokio-util/Cargo.toml
+++ b/tokio-util/Cargo.toml
@@ -25,11 +25,12 @@ publish = false
default = []
# Shorthand for enabling everything
-full = ["codec", "udp", "compat"]
+full = ["codec", "udp", "compat", "io"]
compat = ["futures-io",]
codec = ["tokio/stream"]
udp = ["tokio/udp"]
+io = []
[dependencies]
tokio = { version = "0.3.0", path = "../tokio" }
diff --git a/tokio-util/src/cfg.rs b/tokio-util/src/cfg.rs
index 27e8c66a..2efa5f09 100644
--- a/tokio-util/src/cfg.rs
+++ b/tokio-util/src/cfg.rs
@@ -27,3 +27,13 @@ macro_rules! cfg_udp {
)*
}
}
+
+macro_rules! cfg_io {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "io")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "io")))]
+ $item
+ )*
+ }
+}
diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs
new file mode 100644
index 00000000..53066c4e
--- /dev/null
+++ b/tokio-util/src/io/mod.rs
@@ -0,0 +1,13 @@
+//! Helpers for IO related tasks.
+//!
+//! These types are often used in combination with hyper or reqwest, as they
+//! allow converting between a hyper [`Body`] and [`AsyncRead`].
+//!
+//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html
+//! [`AsyncRead`]: tokio::io::AsyncRead
+
+mod reader_stream;
+mod stream_reader;
+
+pub use self::reader_stream::ReaderStream;
+pub use self::stream_reader::StreamReader;
diff --git a/tokio-util/src/io/reader_stream.rs b/tokio-util/src/io/reader_stream.rs
new file mode 100644
index 00000000..bde7ccee
--- /dev/null
+++ b/tokio-util/src/io/reader_stream.rs
@@ -0,0 +1,100 @@
+use bytes::{Bytes, BytesMut};
+use futures_core::stream::Stream;
+use pin_project_lite::pin_project;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::AsyncRead;
+
+const CAPACITY: usize = 4096;
+
+pin_project! {
+ /// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks.
+ ///
+ /// This stream is fused. It performs the inverse operation of
+ /// [`StreamReader`].
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ /// use tokio::stream::StreamExt;
+ /// use tokio_util::io::ReaderStream;
+ ///
+ /// // Create a stream of data.
+ /// let data = b"hello, world!";
+ /// let mut stream = ReaderStream::new(&data[..]);
+ ///
+ /// // Read all of the chunks into a vector.
+ /// let mut stream_contents = Vec::new();
+ /// while let Some(chunk) = stream.next().await {
+ /// stream_contents.extend_from_slice(&chunk?);
+ /// }
+ ///
+ /// // Once the chunks are concatenated, we should have the
+ /// // original data.
+ /// assert_eq!(stream_contents, data);
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ /// [`StreamReader`]: crate::io::StreamReader
+ /// [`Stream`]: tokio::stream::Stream
+ #[derive(Debug)]
+ pub struct ReaderStream<R> {
+ // Reader itself.
+ //
+ // This value is `None` if the stream has terminated.
+ #[pin]
+ reader: Option<R>,
+ // Working buffer, used to optimize allocations.
+ buf: BytesMut,
+ }
+}
+
+impl<R: AsyncRead> ReaderStream<R> {
+ /// Convert an [`AsyncRead`] into a [`Stream`] with item type
+ /// `Result<Bytes, std::io::Error>`.
+ ///
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ /// [`Stream`]: tokio::stream::Stream
+ pub fn new(reader: R) -> Self {
+ ReaderStream {
+ reader: Some(reader),
+ buf: BytesMut::new(),
+ }
+ }
+}
+
+impl<R: AsyncRead> Stream for ReaderStream<R> {
+ type Item = std::io::Result<Bytes>;
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ let mut this = self.as_mut().project();
+
+ let reader = match this.reader.as_pin_mut() {
+ Some(r) => r,
+ None => return Poll::Ready(None),
+ };
+
+ if this.buf.capacity() == 0 {
+ this.buf.reserve(CAPACITY);
+ }
+
+ match reader.poll_read_buf(cx, &mut this.buf) {
+ Poll::Pending => Poll::Pending,
+ Poll::Ready(Err(err)) => {
+ self.project().reader.set(None);
+ Poll::Ready(Some(Err(err)))
+ }
+ Poll::Ready(Ok(0)) => {
+ self.project().reader.set(None);
+ Poll::Ready(None)
+ }
+ Poll::Ready(Ok(_)) => {
+ let chunk = this.buf.split();
+ Poll::Ready(Some(Ok(chunk.freeze())))
+ }
+ }
+ }
+}
diff --git a/tokio-util/src/io/stream_reader.rs b/tokio-util/src/io/stream_reader.rs
new file mode 100644
index 00000000..5c3ab019
--- /dev/null
+++ b/tokio-util/src/io/stream_reader.rs
@@ -0,0 +1,181 @@
+use bytes::{Buf, BufMut};
+use futures_core::stream::Stream;
+use pin_project_lite::pin_project;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
+
+pin_project! {
+ /// Convert a [`Stream`] of byte chunks into an [`AsyncRead`].
+ ///
+ /// This type performs the inverse operation of [`ReaderStream`].
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use bytes::Bytes;
+ /// use tokio::io::{AsyncReadExt, Result};
+ /// use tokio_util::io::StreamReader;
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ ///
+ /// // Create a stream from an iterator.
+ /// let stream = tokio::stream::iter(vec![
+ /// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
+ /// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
+ /// Result::Ok(Bytes::from_static(&[8, 9, 10, 11])),
+ /// ]);
+ ///
+ /// // Convert it to an AsyncRead.
+ /// let mut read = StreamReader::new(stream);
+ ///
+ /// // Read five bytes from the stream.
+ /// let mut buf = [0; 5];
+ /// read.read_exact(&mut buf).await?;
+ /// assert_eq!(buf, [0, 1, 2, 3, 4]);
+ ///
+ /// // Read the rest of the current chunk.
+ /// assert_eq!(read.read(&mut buf).await?, 3);
+ /// assert_eq!(&buf[..3], [5, 6, 7]);
+ ///
+ /// // Read the next chunk.
+ /// assert_eq!(read.read(&mut buf).await?, 4);
+ /// assert_eq!(&buf[..4], [8, 9, 10, 11]);
+ ///
+ /// // We have now reached the end.
+ /// assert_eq!(read.read(&mut buf).await?, 0);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ /// [`Stream`]: tokio::stream::Stream
+ /// [`ReaderStream`]: crate::io::ReaderStream
+ #[derive(Debug)]
+ pub struct StreamReader<S, B> {
+ #[pin]
+ inner: S,
+ chunk: Option<B>,
+ }
+}
+
+impl<S, B, E> StreamReader<S, B>
+where
+ S: Stream<Item = Result<B, E>>,
+ B: Buf,
+ E: Into<std::io::Error>,
+{
+ /// Convert a stream of byte chunks into an [`AsyncRead`](tokio::io::AsyncRead).
+ ///
+ /// The item should be a [`Result`] with the ok variant being something that
+ /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error
+ /// should be convertible into an [io error].
+ ///
+ /// [`Result`]: std::result::Result
+ /// [`Buf`]: bytes::Buf
+ /// [io error]: std::io::Error
+ pub fn new(stream: S) -> Self {
+ Self {
+ inner: stream,
+ chunk: None,
+ }
+ }
+
+ /// Do we have a chunk and is it non-empty?
+ fn has_chunk(self: Pin<&mut Self>) -> bool {
+ if let Some(chunk) = self.project().chunk {
+ chunk.remaining() > 0
+ } else {
+ false
+ }
+ }
+}
+
+impl<S, B, E> AsyncRead for StreamReader<S, B>
+where
+ S: Stream<Item = Result<B, E>>,
+ B: Buf,
+ E: Into<std::io::Error>,
+{
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ if buf.remaining() == 0 {
+ return Poll::Ready(Ok(()));
+ }
+
+ let inner_buf = match self.as_mut().poll_fill_buf(cx) {
+ Poll::Ready(Ok(buf)) => buf,
+ Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
+ Poll::Pending => return Poll::Pending,
+ };
+ let len = std::cmp::min(inner_buf.len(), buf.remaining());
+ buf.append(&inner_buf[..len]);
+
+ self.consume(len);
+ Poll::Ready(Ok(()))
+ }
+ fn poll_read_buf<BM: BufMut>(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut BM,
+ ) -> Poll<io::Result<usize>>
+ where
+ Self: Sized,
+ {
+ if !buf.has_remaining_mut() {
+ return Poll::Ready(Ok(0));
+ }
+
+ let inner_buf = match self.as_mut().poll_fill_buf(cx) {
+ Poll::Ready(Ok(buf)) => buf,
+ Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
+ Poll::Pending => return Poll::Pending,
+ };
+ let len = std::cmp::min(inner_buf.len(), buf.remaining_mut());
+ buf.put_slice(&inner_buf[..len]);
+
+ self.consume(len);
+ Poll::Ready(Ok(len))
+ }
+}
+
+impl<S, B, E> AsyncBufRead for StreamReader<S, B>
+where
+ S: Stream<Item = Result<B, E>>,
+ B: Buf,
+ E: Into<std::io::Error>,
+{
+ fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
+ loop {
+ if self.as_mut().has_chunk() {
+ // This unwrap is very sad, but it can't be avoided.
+ let buf = self.project().chunk.as_ref().unwrap().bytes();
+ return Poll::Ready(Ok(buf));
+ } else {
+ match self.as_mut().project().inner.poll_next(cx) {
+ Poll::Ready(Some(Ok(chunk))) => {
+ // Go around the loop in case the chunk is empty.
+ *self.as_mut().project().chunk = Some(chunk);
+ }
+ Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
+ Poll::Ready(None) => return Poll::Ready(Ok(&[])),
+ Poll::Pending => return Poll::Pending,
+ }
+ }
+ }
+ }
+ fn consume(self: Pin<&mut Self>, amt: usize) {
+ if amt > 0 {
+ self.project()
+ .chunk
+ .as_mut()
+ .expect("No chunk present")
+ .advance(amt);
+ }
+ }
+}
diff --git a/tokio-util/src/lib.rs b/tokio-util/src/lib.rs
index 3e9a3b7e..49733c6a 100644
--- a/tokio-util/src/lib.rs
+++ b/tokio-util/src/lib.rs
@@ -38,6 +38,10 @@ cfg_compat! {
pub mod compat;
}
+cfg_io! {
+ pub mod io;
+}
+
pub mod context;
pub mod sync;
diff --git a/tokio-util/tests/io_reader_stream.rs b/tokio-util/tests/io_reader_stream.rs
new file mode 100644
index 00000000..b906de09
--- /dev/null
+++ b/tokio-util/tests/io_reader_stream.rs
@@ -0,0 +1,65 @@
+#![warn(rust_2018_idioms)]
+
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::{AsyncRead, ReadBuf};
+use tokio::stream::StreamExt;
+
+/// produces at most `remaining` zeros, that returns error.
+/// each time it reads at most 31 byte.
+struct Reader {
+ remaining: usize,
+}
+
+impl AsyncRead for Reader {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ let this = Pin::into_inner(self);
+ assert_ne!(buf.remaining(), 0);
+ if this.remaining > 0 {
+ let n = std::cmp::min(this.remaining, buf.remaining());
+ let n = std::cmp::min(n, 31);
+ for x in &mut buf.initialize_unfilled_to(n)[..n] {
+ *x = 0;
+ }
+ buf.add_filled(n);
+ this.remaining -= n;
+ Poll::Ready(Ok(()))
+ } else {
+ Poll::Ready(Err(std::io::Error::from_raw_os_error(22)))
+ }
+ }
+}
+
+#[tokio::test]
+async fn correct_behavior_on_errors() {
+ let reader = Reader { remaining: 8000 };
+ let mut stream = tokio_util::io::ReaderStream::new(reader);
+ let mut zeros_received = 0;
+ let mut had_error = false;
+ loop {
+ let item = stream.next().await.unwrap();
+ println!("{:?}", item);
+ match item {
+ Ok(bytes) => {
+ let bytes = &*bytes;
+ for byte in bytes {
+ assert_eq!(*byte, 0);
+ zeros_received += 1;
+ }
+ }
+ Err(_) => {
+ assert!(!had_error);
+ had_error = true;
+ break;
+ }
+ }
+ }
+
+ assert!(had_error);
+ assert_eq!(zeros_received, 8000);
+ assert!(stream.next().await.is_none());
+}
diff --git a/tokio-util/tests/io_stream_reader.rs b/tokio-util/tests/io_stream_reader.rs
new file mode 100644
index 00000000..b0ed1d2d
--- /dev/null
+++ b/tokio-util/tests/io_stream_reader.rs
@@ -0,0 +1,35 @@
+#![warn(rust_2018_idioms)]
+
+use bytes::Bytes;
+use tokio::io::AsyncReadExt;
+use tokio::stream::iter;
+use tokio_util::io::StreamReader;
+
+#[tokio::test]
+async fn test_stream_reader() -> std::io::Result<()> {
+ let stream = iter(vec![
+ std::io::Result::Ok(Bytes::from_static(&[])),
+ Ok(Bytes::from_static(&[0, 1, 2, 3])),
+ Ok(Bytes::from_static(&[])),
+ Ok(Bytes::from_static(&[4, 5, 6, 7])),
+ Ok(Bytes::from_static(&[])),
+ Ok(Bytes::from_static(&[8, 9, 10, 11])),
+ Ok(Bytes::from_static(&[])),
+ ]);
+
+ let mut read = StreamReader::new(stream);
+
+ let mut buf = [0; 5];
+ read.read_exact(&mut buf).await?;
+ assert_eq!(buf, [0, 1, 2, 3, 4]);
+
+ assert_eq!(read.read(&mut buf).await?, 3);
+ assert_eq!(&buf[..3], [5, 6, 7]);
+
+ assert_eq!(read.read(&mut buf).await?, 4);
+ assert_eq!(&buf[..4], [8, 9, 10, 11]);
+
+ assert_eq!(read.read(&mut buf).await?, 0);
+
+ Ok(())
+}
diff --git a/tokio-util/tests/sync_cancellation_token.rs b/tokio-util/tests/sync_cancellation_token.rs
index c65a6425..438e5d5e 100644
--- a/tokio-util/tests/sync_cancellation_token.rs
+++ b/tokio-util/tests/sync_cancellation_token.rs
@@ -1,3 +1,5 @@
+#![warn(rust_2018_idioms)]
+
use tokio::pin;
use tokio_util::sync::CancellationToken;
diff --git a/tokio-util/tests/udp.rs b/tokio-util/tests/udp.rs
index d0320beb..4820ac72 100644
--- a/tokio-util/tests/udp.rs
+++ b/tokio-util/tests/udp.rs
@@ -1,3 +1,5 @@
+#![warn(rust_2018_idioms)]
+
use tokio::{net::UdpSocket, stream::StreamExt};
use tokio_util::codec::{Decoder, Encoder, LinesCodec};
use tokio_util::udp::UdpFramed;