From 5bf06f2b5a81ae7b5b8adfe4a44fab033f4156cf Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Fri, 24 Jan 2020 20:26:55 -0800 Subject: future: provide try_join! macro (#2169) Provides a `try_join!` macro that supports concurrently driving multiple `Result` futures on the same task and await the completion of all the futures as `Ok` or the **first** `Err` future. --- tokio/src/future/maybe_done.rs | 2 +- tokio/src/macros/join.rs | 6 ++ tokio/src/macros/mod.rs | 3 + tokio/src/macros/try_join.rs | 131 +++++++++++++++++++++++++++++++++++++++++ tokio/tests/macros_join.rs | 6 +- tokio/tests/macros_try_join.rs | 100 +++++++++++++++++++++++++++++++ 6 files changed, 244 insertions(+), 4 deletions(-) create mode 100644 tokio/src/macros/try_join.rs create mode 100644 tokio/tests/macros_try_join.rs diff --git a/tokio/src/future/maybe_done.rs b/tokio/src/future/maybe_done.rs index e93af521..1e083ad7 100644 --- a/tokio/src/future/maybe_done.rs +++ b/tokio/src/future/maybe_done.rs @@ -30,7 +30,7 @@ impl MaybeDone { /// The output of this method will be [`Some`] if and only if the inner /// future has been completed and [`take_output`](MaybeDone::take_output) /// has not yet been called. - pub(crate) fn output_mut(self: Pin<&mut Self>) -> Option<&mut Fut::Output> { + pub fn output_mut(self: Pin<&mut Self>) -> Option<&mut Fut::Output> { unsafe { let this = self.get_unchecked_mut(); match this { diff --git a/tokio/src/macros/join.rs b/tokio/src/macros/join.rs index c164bb04..68da021e 100644 --- a/tokio/src/macros/join.rs +++ b/tokio/src/macros/join.rs @@ -8,6 +8,12 @@ /// concurrently on the same task. Each async expression evaluates to a future /// and the futures from each expression are multiplexed on the current task. /// +/// When working with async expressions returning `Result`, `join!` will wait +/// for **all** branches complete regardless if any complete with `Err`. Use +/// [`try_join!`] to return early when `Err` is encountered. +/// +/// [`try_join!`]: macro@try_join +/// /// # Notes /// /// The supplied futures are stored inline and does not require allocating a diff --git a/tokio/src/macros/mod.rs b/tokio/src/macros/mod.rs index b5b53ccb..85420d21 100644 --- a/tokio/src/macros/mod.rs +++ b/tokio/src/macros/mod.rs @@ -27,6 +27,9 @@ cfg_macros! { #[macro_use] mod thread_local; +#[macro_use] +mod try_join; + // Includes re-exports needed to implement macros #[doc(hidden)] pub mod support; diff --git a/tokio/src/macros/try_join.rs b/tokio/src/macros/try_join.rs new file mode 100644 index 00000000..9624e7f6 --- /dev/null +++ b/tokio/src/macros/try_join.rs @@ -0,0 +1,131 @@ +/// Wait on multiple concurrent branches, returning when **all** branches +/// complete with `Ok(_)` or on the first `Err(_)`. +/// +/// The `try_join!` macro must be used inside of async functions, closures, and +/// blocks. +/// +/// Similar to [`join!`], the `try_join!` macro takes a list of async +/// expressions and evaluates them concurrently on the same task. Each async +/// expression evaluates to a future and the futures from each expression are +/// multiplexed on the current task. The `try_join!` macro returns when **all** +/// branches return with `Ok` or when the **first** branch returns with `Err`. +/// +/// [`join!`]: macro@join +/// +/// # Notes +/// +/// The supplied futures are stored inline and does not require allocating a +/// `Vec`. +/// +/// ### Runtime characteristics +/// +/// By running all async expressions on the current task, the expressions are +/// able to run **concurrently** but not in **parallel**. This means all +/// expressions are run on the same thread and if one branch blocks the thread, +/// all other expressions will be unable to continue. If parallelism is +/// required, spawn each async expression using [`tokio::spawn`] and pass the +/// join handle to `try_join!`. +/// +/// [`tokio::spawn`]: crate::spawn +/// +/// # Examples +/// +/// Basic try_join with two branches. +/// +/// ``` +/// async fn do_stuff_async() -> Result<(), &'static str> { +/// // async work +/// # Ok(()) +/// } +/// +/// async fn more_async_work() -> Result<(), &'static str> { +/// // more here +/// # Ok(()) +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let res = tokio::try_join!( +/// do_stuff_async(), +/// more_async_work()); +/// +/// match res { +/// Ok((first, second)) => { +/// // do something with the values +/// } +/// Err(err) => { +/// println!("processing failed; error = {}", err); +/// } +/// } +/// } +/// ``` +#[macro_export] +macro_rules! try_join { + (@ { + // One `_` for each branch in the `try_join!` macro. This is not used once + // normalization is complete. + ( $($count:tt)* ) + + // Normalized try_join! branches + $( ( $($skip:tt)* ) $e:expr, )* + + }) => {{ + use $crate::macros::support::{maybe_done, poll_fn, Future, Pin}; + use $crate::macros::support::Poll::{Ready, Pending}; + + // Safety: nothing must be moved out of `futures`. This is to satisfy + // the requirement of `Pin::new_unchecked` called below. + let mut futures = ( $( maybe_done($e), )* ); + + poll_fn(move |cx| { + let mut is_pending = false; + + $( + // Extract the future for this branch from the tuple. + let ( $($skip,)* fut, .. ) = &mut futures; + + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; + + // Try polling + if fut.as_mut().poll(cx).is_pending() { + is_pending = true; + } else if fut.as_mut().output_mut().expect("expected completed future").is_err() { + return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap())) + } + )* + + if is_pending { + Pending + } else { + Ready(Ok(($({ + // Extract the future for this branch from the tuple. + let ( $($skip,)* fut, .. ) = &mut futures; + + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; + + fut + .take_output() + .expect("expected completed future") + .ok() + .expect("expected Ok(_)") + },)*))) + } + }).await + }}; + + // ===== Normalize ===== + + (@ { ( $($s:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { + $crate::try_join!(@{ ($($s)* _) $($t)* ($($s)*) $e, } $($r)*) + }; + + // ===== Entry point ===== + + ( $($e:expr),* $(,)?) => { + $crate::try_join!(@{ () } $($e,)*) + }; +} diff --git a/tokio/tests/macros_join.rs b/tokio/tests/macros_join.rs index ab168e81..d9b748d9 100644 --- a/tokio/tests/macros_join.rs +++ b/tokio/tests/macros_join.rs @@ -30,9 +30,9 @@ async fn sync_two_lit_expr_no_comma() { } #[tokio::test] -async fn sync_two_await() { - let (tx1, rx1) = oneshot::channel(); - let (tx2, rx2) = oneshot::channel(); +async fn two_await() { + let (tx1, rx1) = oneshot::channel::<&str>(); + let (tx2, rx2) = oneshot::channel::(); let mut join = task::spawn(async { tokio::join!(async { rx1.await.unwrap() }, async { rx2.await.unwrap() }) diff --git a/tokio/tests/macros_try_join.rs b/tokio/tests/macros_try_join.rs new file mode 100644 index 00000000..faa55421 --- /dev/null +++ b/tokio/tests/macros_try_join.rs @@ -0,0 +1,100 @@ +use tokio::sync::oneshot; +use tokio_test::{assert_pending, assert_ready, task}; + +#[tokio::test] +async fn sync_one_lit_expr_comma() { + let foo = tokio::try_join!(async { ok(1) },); + + assert_eq!(foo, Ok((1,))); +} + +#[tokio::test] +async fn sync_one_lit_expr_no_comma() { + let foo = tokio::try_join!(async { ok(1) }); + + assert_eq!(foo, Ok((1,))); +} + +#[tokio::test] +async fn sync_two_lit_expr_comma() { + let foo = tokio::try_join!(async { ok(1) }, async { ok(2) },); + + assert_eq!(foo, Ok((1, 2))); +} + +#[tokio::test] +async fn sync_two_lit_expr_no_comma() { + let foo = tokio::try_join!(async { ok(1) }, async { ok(2) }); + + assert_eq!(foo, Ok((1, 2))); +} + +#[tokio::test] +async fn two_await() { + let (tx1, rx1) = oneshot::channel::<&str>(); + let (tx2, rx2) = oneshot::channel::(); + + let mut join = + task::spawn(async { tokio::try_join!(async { rx1.await }, async { rx2.await }) }); + + assert_pending!(join.poll()); + + tx2.send(123).unwrap(); + assert!(join.is_woken()); + assert_pending!(join.poll()); + + tx1.send("hello").unwrap(); + assert!(join.is_woken()); + let res: Result<(&str, u32), _> = assert_ready!(join.poll()); + + assert_eq!(Ok(("hello", 123)), res); +} + +#[tokio::test] +async fn err_abort_early() { + let (tx1, rx1) = oneshot::channel::<&str>(); + let (tx2, rx2) = oneshot::channel::(); + let (_tx3, rx3) = oneshot::channel::(); + + let mut join = task::spawn(async { + tokio::try_join!(async { rx1.await }, async { rx2.await }, async { + rx3.await + }) + }); + + assert_pending!(join.poll()); + + tx2.send(123).unwrap(); + assert!(join.is_woken()); + assert_pending!(join.poll()); + + drop(tx1); + assert!(join.is_woken()); + + let res = assert_ready!(join.poll()); + + assert!(res.is_err()); +} + +#[test] +fn join_size() { + use futures::future; + use std::mem; + + let fut = async { + let ready = future::ready(ok(0i32)); + tokio::try_join!(ready) + }; + assert_eq!(mem::size_of_val(&fut), 16); + + let fut = async { + let ready1 = future::ready(ok(0i32)); + let ready2 = future::ready(ok(0i32)); + tokio::try_join!(ready1, ready2) + }; + assert_eq!(mem::size_of_val(&fut), 28); +} + +fn ok(val: T) -> Result { + Ok(val) +} -- cgit v1.2.3