diff options
author | Carl Lerche <me@carllerche.com> | 2020-01-13 14:44:06 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-13 14:44:06 -0800 |
commit | eb1a8e1792b2c4b296be47a0681421c90bbdbf7a (patch) | |
tree | 602f3642f60167f8753c90cb170dcbaa34247ffb /tokio | |
parent | 5b091fa3f0c3a06047d02ca6892f75c3e15040df (diff) |
stream: add `StreamExt::collect()` (#2109)
Provides an asynchronous equivalent to `Iterator::collect()`. A sealed
`FromStream` trait is added. Stabilization is pending Rust supporting
`async` trait fns.
Diffstat (limited to 'tokio')
-rw-r--r-- | tokio/src/net/addr.rs | 9 | ||||
-rw-r--r-- | tokio/src/stream/collect.rs | 246 | ||||
-rw-r--r-- | tokio/src/stream/mod.rs | 76 | ||||
-rw-r--r-- | tokio/tests/stream_collect.rs | 171 |
4 files changed, 497 insertions, 5 deletions
diff --git a/tokio/src/net/addr.rs b/tokio/src/net/addr.rs index 8e3bf434..d8d89c40 100644 --- a/tokio/src/net/addr.rs +++ b/tokio/src/net/addr.rs @@ -14,12 +14,11 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV /// # Calling /// /// Currently, this trait is only used as an argument to Tokio functions that -/// need to reference a target socket address. +/// need to reference a target socket address. To perform a `SocketAddr` +/// conversion directly, use [`lookup_host()`](super::lookup_host()). /// -/// This trait is sealed and is intended to be opaque. Users of Tokio should -/// only use `ToSocketAddrs` in trait bounds and __must not__ attempt to call -/// the functions directly or reference associated types. Changing these is not -/// considered a breaking change. +/// This trait is sealed and is intended to be opaque. The details of the trait +/// will change. Stabilization is pending enhancements to the Rust langague. pub trait ToSocketAddrs: sealed::ToSocketAddrsPriv {} type ReadyFuture<T> = future::Ready<io::Result<T>>; diff --git a/tokio/src/stream/collect.rs b/tokio/src/stream/collect.rs new file mode 100644 index 00000000..e8a58147 --- /dev/null +++ b/tokio/src/stream/collect.rs @@ -0,0 +1,246 @@ +use crate::stream::Stream; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use core::future::Future; +use core::mem; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +// Do not export this struct until `FromStream` can be unsealed. +pin_project! { + /// Stream returned by the [`collect`](super::StreamExt::collect) method. + #[must_use = "streams do nothing unless polled"] + #[derive(Debug)] + pub struct Collect<T, U> + where + T: Stream, + U: FromStream<T::Item>, + { + #[pin] + stream: T, + collection: U::Collection, + } +} + +/// Convert from a [`Stream`](crate::stream::Stream). +/// +/// This trait is not intended to be used directly. Instead, call +/// [`StreamExt::collect()`](super::StreamExt::collect). +/// +/// # Implementing +/// +/// Currently, this trait may not be implemented by third parties. The trait is +/// sealed in order to make changes in the future. Stabilization is pending +/// enhancements to the Rust langague. +pub trait FromStream<T>: sealed::FromStreamPriv<T> {} + +impl<T, U> Collect<T, U> +where + T: Stream, + U: FromStream<T::Item>, +{ + pub(super) fn new(stream: T) -> Collect<T, U> { + let (lower, upper) = stream.size_hint(); + let collection = U::initialize(lower, upper); + + Collect { stream, collection } + } +} + +impl<T, U> Future for Collect<T, U> +where + T: Stream, + U: FromStream<T::Item>, +{ + type Output = U; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<U> { + use Poll::Ready; + + loop { + let mut me = self.as_mut().project(); + + let item = match ready!(me.stream.poll_next(cx)) { + Some(item) => item, + None => { + return Ready(U::finalize(&mut me.collection)); + } + }; + + if !U::extend(&mut me.collection, item) { + return Ready(U::finalize(&mut me.collection)); + } + } + } +} + +// ===== FromStream implementations + +impl FromStream<()> for () {} + +impl sealed::FromStreamPriv<()> for () { + type Collection = (); + + fn initialize(_lower: usize, _upper: Option<usize>) {} + + fn extend(_collection: &mut (), _item: ()) -> bool { + true + } + + fn finalize(_collection: &mut ()) {} +} + +impl<T: AsRef<str>> FromStream<T> for String {} + +impl<T: AsRef<str>> sealed::FromStreamPriv<T> for String { + type Collection = String; + + fn initialize(_lower: usize, _upper: Option<usize>) -> String { + String::new() + } + + fn extend(collection: &mut String, item: T) -> bool { + collection.push_str(item.as_ref()); + true + } + + fn finalize(collection: &mut String) -> String { + mem::replace(collection, String::new()) + } +} + +impl<T> FromStream<T> for Vec<T> {} + +impl<T> sealed::FromStreamPriv<T> for Vec<T> { + type Collection = Vec<T>; + + fn initialize(lower: usize, _upper: Option<usize>) -> Vec<T> { + Vec::with_capacity(lower) + } + + fn extend(collection: &mut Vec<T>, item: T) -> bool { + collection.push(item); + true + } + + fn finalize(collection: &mut Vec<T>) -> Vec<T> { + mem::replace(collection, vec![]) + } +} + +impl<T> FromStream<T> for Box<[T]> {} + +impl<T> sealed::FromStreamPriv<T> for Box<[T]> { + type Collection = Vec<T>; + + fn initialize(lower: usize, upper: Option<usize>) -> Vec<T> { + <Vec<T> as sealed::FromStreamPriv<T>>::initialize(lower, upper) + } + + fn extend(collection: &mut Vec<T>, item: T) -> bool { + <Vec<T> as sealed::FromStreamPriv<T>>::extend(collection, item) + } + + fn finalize(collection: &mut Vec<T>) -> Box<[T]> { + <Vec<T> as sealed::FromStreamPriv<T>>::finalize(collection).into_boxed_slice() + } +} + +impl<T, U, E> FromStream<Result<T, E>> for Result<U, E> where U: FromStream<T> {} + +impl<T, U, E> sealed::FromStreamPriv<Result<T, E>> for Result<U, E> +where + U: FromStream<T>, +{ + type Collection = Result<U::Collection, E>; + + fn initialize(lower: usize, upper: Option<usize>) -> Result<U::Collection, E> { + Ok(U::initialize(lower, upper)) + } + + fn extend(collection: &mut Self::Collection, item: Result<T, E>) -> bool { + assert!(collection.is_ok()); + match item { + Ok(item) => { + let collection = collection.as_mut().ok().expect("invalid state"); + U::extend(collection, item) + } + Err(err) => { + *collection = Err(err); + false + } + } + } + + fn finalize(collection: &mut Self::Collection) -> Result<U, E> { + if let Ok(collection) = collection.as_mut() { + Ok(U::finalize(collection)) + } else { + let res = mem::replace(collection, Ok(U::initialize(0, Some(0)))); + + if let Err(err) = res { + Err(err) + } else { + unreachable!(); + } + } + } +} + +impl<T: Buf> FromStream<T> for Bytes {} + +impl<T: Buf> sealed::FromStreamPriv<T> for Bytes { + type Collection = BytesMut; + + fn initialize(_lower: usize, _upper: Option<usize>) -> BytesMut { + BytesMut::new() + } + + fn extend(collection: &mut BytesMut, item: T) -> bool { + collection.put(item); + true + } + + fn finalize(collection: &mut BytesMut) -> Bytes { + mem::replace(collection, BytesMut::new()).freeze() + } +} + +impl<T: Buf> FromStream<T> for BytesMut {} + +impl<T: Buf> sealed::FromStreamPriv<T> for BytesMut { + type Collection = BytesMut; + + fn initialize(_lower: usize, _upper: Option<usize>) -> BytesMut { + BytesMut::new() + } + + fn extend(collection: &mut BytesMut, item: T) -> bool { + collection.put(item); + true + } + + fn finalize(collection: &mut BytesMut) -> BytesMut { + mem::replace(collection, BytesMut::new()) + } +} + +pub(crate) mod sealed { + #[doc(hidden)] + pub trait FromStreamPriv<T> { + /// Intermediate type used during collection process + type Collection; + + /// Initialize the collection + fn initialize(lower: usize, upper: Option<usize>) -> Self::Collection; + + /// Extend the collection with the received item + /// + /// Return `true` to continue streaming, `false` complete collection. + fn extend(collection: &mut Self::Collection, item: T) -> bool; + + /// Finalize collection into target type. + fn finalize(collection: &mut Self::Collection) -> Self; + } +} diff --git a/tokio/src/stream/mod.rs b/tokio/src/stream/mod.rs index 9be1b102..081fe817 100644 --- a/tokio/src/stream/mod.rs +++ b/tokio/src/stream/mod.rs @@ -13,6 +13,10 @@ use any::AnyFuture; mod chain; use chain::Chain; +mod collect; +use collect::Collect; +pub use collect::FromStream; + mod empty; pub use empty::{empty, Empty}; @@ -577,6 +581,78 @@ pub trait StreamExt: Stream { { Chain::new(self, other) } + + /// Drain stream pushing all emitted values into a collection. + /// + /// `collect` streams all values, awaiting as needed. Values are pushed into + /// a collection. A number of different target collection types are + /// supported, including [`Vec`](std::vec::Vec), + /// [`String`](std::string::String), and [`Bytes`](bytes::Bytes). + /// + /// # `Result` + /// + /// `collect()` can also be used with streams of type `Result<T, E>` where + /// `T: FromStream<_>`. In this case, `collect()` will stream as long as + /// values yielded from the stream are `Ok(_)`. If `Err(_)` is encountered, + /// streaming is terminated and `collect()` returns the `Err`. + /// + /// # Notes + /// + /// `FromStream` is currently a sealed trait. Stabilization is pending + /// enhancements to the Rust langague. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ``` + /// use tokio::stream::{self, StreamExt}; + /// + /// #[tokio::main] + /// async fn main() { + /// let doubled: Vec<i32> = + /// stream::iter(vec![1, 2, 3]) + /// .map(|x| x * 2) + /// .collect() + /// .await; + /// + /// assert_eq!(vec![2, 4, 6], doubled); + /// } + /// ``` + /// + /// Collecting a stream of `Result` values + /// + /// ``` + /// use tokio::stream::{self, StreamExt}; + /// + /// #[tokio::main] + /// async fn main() { + /// // A stream containing only `Ok` values will be collected + /// let values: Result<Vec<i32>, &str> = + /// stream::iter(vec![Ok(1), Ok(2), Ok(3)]) + /// .collect() + /// .await; + /// + /// assert_eq!(Ok(vec![1, 2, 3]), values); + /// + /// // A stream containing `Err` values will return the first error. + /// let results = vec![Ok(1), Err("no"), Ok(2), Ok(3), Err("nein")]; + /// + /// let values: Result<Vec<i32>, &str> = + /// stream::iter(results) + /// .collect() + /// .await; + /// + /// assert_eq!(Err("no"), values); + /// } + /// ``` + fn collect<T>(self) -> Collect<Self, T> + where + T: FromStream<Self::Item>, + Self: Sized, + { + Collect::new(self) + } } impl<St: ?Sized> StreamExt for St where St: Stream {} diff --git a/tokio/tests/stream_collect.rs b/tokio/tests/stream_collect.rs new file mode 100644 index 00000000..a4bee0d1 --- /dev/null +++ b/tokio/tests/stream_collect.rs @@ -0,0 +1,171 @@ +use tokio::stream::{self, StreamExt}; +use tokio::sync::mpsc; +use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task}; + +use bytes::{Bytes, BytesMut}; + +#[tokio::test] +async fn empty_unit() { + // Drains the stream. + let mut iter = vec![(), (), ()].into_iter(); + let _: () = stream::iter(&mut iter).collect().await; + assert!(iter.next().is_none()); +} + +#[tokio::test] +async fn empty_vec() { + let coll: Vec<u32> = stream::empty().collect().await; + assert!(coll.is_empty()); +} + +#[tokio::test] +async fn empty_box_slice() { + let coll: Box<[u32]> = stream::empty().collect().await; + assert!(coll.is_empty()); +} + +#[tokio::test] +async fn empty_bytes() { + let coll: Bytes = stream::empty::<&[u8]>().collect().await; + assert!(coll.is_empty()); +} + +#[tokio::test] +async fn empty_bytes_mut() { + let coll: BytesMut = stream::empty::<&[u8]>().collect().await; + assert!(coll.is_empty()); +} + +#[tokio::test] +async fn empty_string() { + let coll: String = stream::empty::<&str>().collect().await; + assert!(coll.is_empty()); +} + +#[tokio::test] +async fn empty_result() { + let coll: Result<Vec<u32>, &str> = stream::empty().collect().await; + assert_eq!(Ok(vec![]), coll); +} + +#[tokio::test] +async fn collect_vec_items() { + let (tx, rx) = mpsc::unbounded_channel(); + let mut fut = task::spawn(rx.collect::<Vec<i32>>()); + + assert_pending!(fut.poll()); + + tx.send(1).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send(2).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!(vec![1, 2], coll); +} + +#[tokio::test] +async fn collect_string_items() { + let (tx, rx) = mpsc::unbounded_channel(); + let mut fut = task::spawn(rx.collect::<String>()); + + assert_pending!(fut.poll()); + + tx.send("hello ".to_string()).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send("world".to_string()).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!("hello world", coll); +} + +#[tokio::test] +async fn collect_str_items() { + let (tx, rx) = mpsc::unbounded_channel(); + let mut fut = task::spawn(rx.collect::<String>()); + + assert_pending!(fut.poll()); + + tx.send("hello ").unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send("world").unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!("hello world", coll); +} + +#[tokio::test] +async fn collect_bytes() { + let (tx, rx) = mpsc::unbounded_channel(); + let mut fut = task::spawn(rx.collect::<Bytes>()); + + assert_pending!(fut.poll()); + + tx.send(&b"hello "[..]).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send(&b"world"[..]).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!(&b"hello world"[..], coll); +} + +#[tokio::test] +async fn collect_results_ok() { + let (tx, rx) = mpsc::unbounded_channel(); + let mut fut = task::spawn(rx.collect::<Result<String, &str>>()); + + assert_pending!(fut.poll()); + + tx.send(Ok("hello ")).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send(Ok("world")).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready_ok!(fut.poll()); + assert_eq!("hello world", coll); +} + +#[tokio::test] +async fn collect_results_err() { + let (tx, rx) = mpsc::unbounded_channel(); + let mut fut = task::spawn(rx.collect::<Result<String, &str>>()); + + assert_pending!(fut.poll()); + + tx.send(Ok("hello ")).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send(Err("oh no")).unwrap(); + assert!(fut.is_woken()); + let err = assert_ready_err!(fut.poll()); + assert_eq!("oh no", err); +} |