summaryrefslogtreecommitdiffstats
path: root/tokio
diff options
context:
space:
mode:
authorCarl Lerche <me@carllerche.com>2020-01-13 14:44:06 -0800
committerGitHub <noreply@github.com>2020-01-13 14:44:06 -0800
commiteb1a8e1792b2c4b296be47a0681421c90bbdbf7a (patch)
tree602f3642f60167f8753c90cb170dcbaa34247ffb /tokio
parent5b091fa3f0c3a06047d02ca6892f75c3e15040df (diff)
stream: add `StreamExt::collect()` (#2109)
Provides an asynchronous equivalent to `Iterator::collect()`. A sealed `FromStream` trait is added. Stabilization is pending Rust supporting `async` trait fns.
Diffstat (limited to 'tokio')
-rw-r--r--tokio/src/net/addr.rs9
-rw-r--r--tokio/src/stream/collect.rs246
-rw-r--r--tokio/src/stream/mod.rs76
-rw-r--r--tokio/tests/stream_collect.rs171
4 files changed, 497 insertions, 5 deletions
diff --git a/tokio/src/net/addr.rs b/tokio/src/net/addr.rs
index 8e3bf434..d8d89c40 100644
--- a/tokio/src/net/addr.rs
+++ b/tokio/src/net/addr.rs
@@ -14,12 +14,11 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV
/// # Calling
///
/// Currently, this trait is only used as an argument to Tokio functions that
-/// need to reference a target socket address.
+/// need to reference a target socket address. To perform a `SocketAddr`
+/// conversion directly, use [`lookup_host()`](super::lookup_host()).
///
-/// This trait is sealed and is intended to be opaque. Users of Tokio should
-/// only use `ToSocketAddrs` in trait bounds and __must not__ attempt to call
-/// the functions directly or reference associated types. Changing these is not
-/// considered a breaking change.
+/// This trait is sealed and is intended to be opaque. The details of the trait
+/// will change. Stabilization is pending enhancements to the Rust langague.
pub trait ToSocketAddrs: sealed::ToSocketAddrsPriv {}
type ReadyFuture<T> = future::Ready<io::Result<T>>;
diff --git a/tokio/src/stream/collect.rs b/tokio/src/stream/collect.rs
new file mode 100644
index 00000000..e8a58147
--- /dev/null
+++ b/tokio/src/stream/collect.rs
@@ -0,0 +1,246 @@
+use crate::stream::Stream;
+
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use core::future::Future;
+use core::mem;
+use core::pin::Pin;
+use core::task::{Context, Poll};
+use pin_project_lite::pin_project;
+
+// Do not export this struct until `FromStream` can be unsealed.
+pin_project! {
+ /// Stream returned by the [`collect`](super::StreamExt::collect) method.
+ #[must_use = "streams do nothing unless polled"]
+ #[derive(Debug)]
+ pub struct Collect<T, U>
+ where
+ T: Stream,
+ U: FromStream<T::Item>,
+ {
+ #[pin]
+ stream: T,
+ collection: U::Collection,
+ }
+}
+
+/// Convert from a [`Stream`](crate::stream::Stream).
+///
+/// This trait is not intended to be used directly. Instead, call
+/// [`StreamExt::collect()`](super::StreamExt::collect).
+///
+/// # Implementing
+///
+/// Currently, this trait may not be implemented by third parties. The trait is
+/// sealed in order to make changes in the future. Stabilization is pending
+/// enhancements to the Rust langague.
+pub trait FromStream<T>: sealed::FromStreamPriv<T> {}
+
+impl<T, U> Collect<T, U>
+where
+ T: Stream,
+ U: FromStream<T::Item>,
+{
+ pub(super) fn new(stream: T) -> Collect<T, U> {
+ let (lower, upper) = stream.size_hint();
+ let collection = U::initialize(lower, upper);
+
+ Collect { stream, collection }
+ }
+}
+
+impl<T, U> Future for Collect<T, U>
+where
+ T: Stream,
+ U: FromStream<T::Item>,
+{
+ type Output = U;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<U> {
+ use Poll::Ready;
+
+ loop {
+ let mut me = self.as_mut().project();
+
+ let item = match ready!(me.stream.poll_next(cx)) {
+ Some(item) => item,
+ None => {
+ return Ready(U::finalize(&mut me.collection));
+ }
+ };
+
+ if !U::extend(&mut me.collection, item) {
+ return Ready(U::finalize(&mut me.collection));
+ }
+ }
+ }
+}
+
+// ===== FromStream implementations
+
+impl FromStream<()> for () {}
+
+impl sealed::FromStreamPriv<()> for () {
+ type Collection = ();
+
+ fn initialize(_lower: usize, _upper: Option<usize>) {}
+
+ fn extend(_collection: &mut (), _item: ()) -> bool {
+ true
+ }
+
+ fn finalize(_collection: &mut ()) {}
+}
+
+impl<T: AsRef<str>> FromStream<T> for String {}
+
+impl<T: AsRef<str>> sealed::FromStreamPriv<T> for String {
+ type Collection = String;
+
+ fn initialize(_lower: usize, _upper: Option<usize>) -> String {
+ String::new()
+ }
+
+ fn extend(collection: &mut String, item: T) -> bool {
+ collection.push_str(item.as_ref());
+ true
+ }
+
+ fn finalize(collection: &mut String) -> String {
+ mem::replace(collection, String::new())
+ }
+}
+
+impl<T> FromStream<T> for Vec<T> {}
+
+impl<T> sealed::FromStreamPriv<T> for Vec<T> {
+ type Collection = Vec<T>;
+
+ fn initialize(lower: usize, _upper: Option<usize>) -> Vec<T> {
+ Vec::with_capacity(lower)
+ }
+
+ fn extend(collection: &mut Vec<T>, item: T) -> bool {
+ collection.push(item);
+ true
+ }
+
+ fn finalize(collection: &mut Vec<T>) -> Vec<T> {
+ mem::replace(collection, vec![])
+ }
+}
+
+impl<T> FromStream<T> for Box<[T]> {}
+
+impl<T> sealed::FromStreamPriv<T> for Box<[T]> {
+ type Collection = Vec<T>;
+
+ fn initialize(lower: usize, upper: Option<usize>) -> Vec<T> {
+ <Vec<T> as sealed::FromStreamPriv<T>>::initialize(lower, upper)
+ }
+
+ fn extend(collection: &mut Vec<T>, item: T) -> bool {
+ <Vec<T> as sealed::FromStreamPriv<T>>::extend(collection, item)
+ }
+
+ fn finalize(collection: &mut Vec<T>) -> Box<[T]> {
+ <Vec<T> as sealed::FromStreamPriv<T>>::finalize(collection).into_boxed_slice()
+ }
+}
+
+impl<T, U, E> FromStream<Result<T, E>> for Result<U, E> where U: FromStream<T> {}
+
+impl<T, U, E> sealed::FromStreamPriv<Result<T, E>> for Result<U, E>
+where
+ U: FromStream<T>,
+{
+ type Collection = Result<U::Collection, E>;
+
+ fn initialize(lower: usize, upper: Option<usize>) -> Result<U::Collection, E> {
+ Ok(U::initialize(lower, upper))
+ }
+
+ fn extend(collection: &mut Self::Collection, item: Result<T, E>) -> bool {
+ assert!(collection.is_ok());
+ match item {
+ Ok(item) => {
+ let collection = collection.as_mut().ok().expect("invalid state");
+ U::extend(collection, item)
+ }
+ Err(err) => {
+ *collection = Err(err);
+ false
+ }
+ }
+ }
+
+ fn finalize(collection: &mut Self::Collection) -> Result<U, E> {
+ if let Ok(collection) = collection.as_mut() {
+ Ok(U::finalize(collection))
+ } else {
+ let res = mem::replace(collection, Ok(U::initialize(0, Some(0))));
+
+ if let Err(err) = res {
+ Err(err)
+ } else {
+ unreachable!();
+ }
+ }
+ }
+}
+
+impl<T: Buf> FromStream<T> for Bytes {}
+
+impl<T: Buf> sealed::FromStreamPriv<T> for Bytes {
+ type Collection = BytesMut;
+
+ fn initialize(_lower: usize, _upper: Option<usize>) -> BytesMut {
+ BytesMut::new()
+ }
+
+ fn extend(collection: &mut BytesMut, item: T) -> bool {
+ collection.put(item);
+ true
+ }
+
+ fn finalize(collection: &mut BytesMut) -> Bytes {
+ mem::replace(collection, BytesMut::new()).freeze()
+ }
+}
+
+impl<T: Buf> FromStream<T> for BytesMut {}
+
+impl<T: Buf> sealed::FromStreamPriv<T> for BytesMut {
+ type Collection = BytesMut;
+
+ fn initialize(_lower: usize, _upper: Option<usize>) -> BytesMut {
+ BytesMut::new()
+ }
+
+ fn extend(collection: &mut BytesMut, item: T) -> bool {
+ collection.put(item);
+ true
+ }
+
+ fn finalize(collection: &mut BytesMut) -> BytesMut {
+ mem::replace(collection, BytesMut::new())
+ }
+}
+
+pub(crate) mod sealed {
+ #[doc(hidden)]
+ pub trait FromStreamPriv<T> {
+ /// Intermediate type used during collection process
+ type Collection;
+
+ /// Initialize the collection
+ fn initialize(lower: usize, upper: Option<usize>) -> Self::Collection;
+
+ /// Extend the collection with the received item
+ ///
+ /// Return `true` to continue streaming, `false` complete collection.
+ fn extend(collection: &mut Self::Collection, item: T) -> bool;
+
+ /// Finalize collection into target type.
+ fn finalize(collection: &mut Self::Collection) -> Self;
+ }
+}
diff --git a/tokio/src/stream/mod.rs b/tokio/src/stream/mod.rs
index 9be1b102..081fe817 100644
--- a/tokio/src/stream/mod.rs
+++ b/tokio/src/stream/mod.rs
@@ -13,6 +13,10 @@ use any::AnyFuture;
mod chain;
use chain::Chain;
+mod collect;
+use collect::Collect;
+pub use collect::FromStream;
+
mod empty;
pub use empty::{empty, Empty};
@@ -577,6 +581,78 @@ pub trait StreamExt: Stream {
{
Chain::new(self, other)
}
+
+ /// Drain stream pushing all emitted values into a collection.
+ ///
+ /// `collect` streams all values, awaiting as needed. Values are pushed into
+ /// a collection. A number of different target collection types are
+ /// supported, including [`Vec`](std::vec::Vec),
+ /// [`String`](std::string::String), and [`Bytes`](bytes::Bytes).
+ ///
+ /// # `Result`
+ ///
+ /// `collect()` can also be used with streams of type `Result<T, E>` where
+ /// `T: FromStream<_>`. In this case, `collect()` will stream as long as
+ /// values yielded from the stream are `Ok(_)`. If `Err(_)` is encountered,
+ /// streaming is terminated and `collect()` returns the `Err`.
+ ///
+ /// # Notes
+ ///
+ /// `FromStream` is currently a sealed trait. Stabilization is pending
+ /// enhancements to the Rust langague.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage:
+ ///
+ /// ```
+ /// use tokio::stream::{self, StreamExt};
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let doubled: Vec<i32> =
+ /// stream::iter(vec![1, 2, 3])
+ /// .map(|x| x * 2)
+ /// .collect()
+ /// .await;
+ ///
+ /// assert_eq!(vec![2, 4, 6], doubled);
+ /// }
+ /// ```
+ ///
+ /// Collecting a stream of `Result` values
+ ///
+ /// ```
+ /// use tokio::stream::{self, StreamExt};
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// // A stream containing only `Ok` values will be collected
+ /// let values: Result<Vec<i32>, &str> =
+ /// stream::iter(vec![Ok(1), Ok(2), Ok(3)])
+ /// .collect()
+ /// .await;
+ ///
+ /// assert_eq!(Ok(vec![1, 2, 3]), values);
+ ///
+ /// // A stream containing `Err` values will return the first error.
+ /// let results = vec![Ok(1), Err("no"), Ok(2), Ok(3), Err("nein")];
+ ///
+ /// let values: Result<Vec<i32>, &str> =
+ /// stream::iter(results)
+ /// .collect()
+ /// .await;
+ ///
+ /// assert_eq!(Err("no"), values);
+ /// }
+ /// ```
+ fn collect<T>(self) -> Collect<Self, T>
+ where
+ T: FromStream<Self::Item>,
+ Self: Sized,
+ {
+ Collect::new(self)
+ }
}
impl<St: ?Sized> StreamExt for St where St: Stream {}
diff --git a/tokio/tests/stream_collect.rs b/tokio/tests/stream_collect.rs
new file mode 100644
index 00000000..a4bee0d1
--- /dev/null
+++ b/tokio/tests/stream_collect.rs
@@ -0,0 +1,171 @@
+use tokio::stream::{self, StreamExt};
+use tokio::sync::mpsc;
+use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task};
+
+use bytes::{Bytes, BytesMut};
+
+#[tokio::test]
+async fn empty_unit() {
+ // Drains the stream.
+ let mut iter = vec![(), (), ()].into_iter();
+ let _: () = stream::iter(&mut iter).collect().await;
+ assert!(iter.next().is_none());
+}
+
+#[tokio::test]
+async fn empty_vec() {
+ let coll: Vec<u32> = stream::empty().collect().await;
+ assert!(coll.is_empty());
+}
+
+#[tokio::test]
+async fn empty_box_slice() {
+ let coll: Box<[u32]> = stream::empty().collect().await;
+ assert!(coll.is_empty());
+}
+
+#[tokio::test]
+async fn empty_bytes() {
+ let coll: Bytes = stream::empty::<&[u8]>().collect().await;
+ assert!(coll.is_empty());
+}
+
+#[tokio::test]
+async fn empty_bytes_mut() {
+ let coll: BytesMut = stream::empty::<&[u8]>().collect().await;
+ assert!(coll.is_empty());
+}
+
+#[tokio::test]
+async fn empty_string() {
+ let coll: String = stream::empty::<&str>().collect().await;
+ assert!(coll.is_empty());
+}
+
+#[tokio::test]
+async fn empty_result() {
+ let coll: Result<Vec<u32>, &str> = stream::empty().collect().await;
+ assert_eq!(Ok(vec![]), coll);
+}
+
+#[tokio::test]
+async fn collect_vec_items() {
+ let (tx, rx) = mpsc::unbounded_channel();
+ let mut fut = task::spawn(rx.collect::<Vec<i32>>());
+
+ assert_pending!(fut.poll());
+
+ tx.send(1).unwrap();
+ assert!(fut.is_woken());
+ assert_pending!(fut.poll());
+
+ tx.send(2).unwrap();
+ assert!(fut.is_woken());
+ assert_pending!(fut.poll());
+
+ drop(tx);
+ assert!(fut.is_woken());
+ let coll = assert_ready!(fut.poll());
+ assert_eq!(vec![1, 2], coll);
+}
+
+#[tokio::test]
+async fn collect_string_items() {
+ let (tx, rx) = mpsc::unbounded_channel();
+ let mut fut = task::spawn(rx.collect::<String>());
+
+ assert_pending!(fut.poll());
+
+ tx.send("hello ".to_string()).unwrap();
+ assert!(fut.is_woken());
+ assert_pending!(fut.poll());
+
+ tx.send("world".to_string()).unwrap();
+ assert!(fut.is_woken());
+ assert_pending!(fut.poll());
+
+ drop(tx);
+ assert!(fut.is_woken());
+ let coll = assert_ready!(fut.poll());
+ assert_eq!("hello world", coll);
+}
+
+#[tokio::test]
+async fn collect_str_items() {
+ let (tx, rx) = mpsc::unbounded_channel();
+ let mut fut = task::spawn(rx.collect::<String>());
+
+ assert_pending!(fut.poll());
+
+ tx.send("hello ").unwrap();
+ assert!(fut.is_woken());
+ assert_pending!(fut.poll());
+
+ tx.send("world").unwrap();
+ assert!(fut.is_woken());
+ assert_pending!(fut.poll());
+
+ drop(tx);
+ assert!(fut.is_woken());
+ let coll = assert_ready!(fut.poll());
+ assert_eq!("hello world", coll);
+}
+
+#[tokio::test]
+async fn collect_bytes() {
+ let (tx, rx) = mpsc::unbounded_channel();
+ let mut fut = task::spawn(rx.collect::<Bytes>());
+
+ assert_pending!(fut.poll());
+
+ tx.send(&b"hello "[..]).unwrap();
+ assert!(fut.is_woken());
+ assert_pending!(fut.poll());
+
+ tx.send(&b"world"[..]).unwrap();
+ assert!(fut.is_woken());
+ assert_pending!(fut.poll());
+
+ drop(tx);
+ assert!(fut.is_woken());
+ let coll = assert_ready!(fut.poll());
+ assert_eq!(&b"hello world"[..], coll);
+}
+
+#[tokio::test]
+async fn collect_results_ok() {
+ let (tx, rx) = mpsc::unbounded_channel();
+ let mut fut = task::spawn(rx.collect::<Result<String, &str>>());
+
+ assert_pending!(fut.poll());
+
+ tx.send(Ok("hello ")).unwrap();
+ assert!(fut.is_woken());
+ assert_pending!(fut.poll());
+
+ tx.send(Ok("world")).unwrap();
+ assert!(fut.is_woken());
+ assert_pending!(fut.poll());
+
+ drop(tx);
+ assert!(fut.is_woken());
+ let coll = assert_ready_ok!(fut.poll());
+ assert_eq!("hello world", coll);
+}
+
+#[tokio::test]
+async fn collect_results_err() {
+ let (tx, rx) = mpsc::unbounded_channel();
+ let mut fut = task::spawn(rx.collect::<Result<String, &str>>());
+
+ assert_pending!(fut.poll());
+
+ tx.send(Ok("hello ")).unwrap();
+ assert!(fut.is_woken());
+ assert_pending!(fut.poll());
+
+ tx.send(Err("oh no")).unwrap();
+ assert!(fut.is_woken());
+ let err = assert_ready_err!(fut.poll());
+ assert_eq!("oh no", err);
+}