diff options
author | Eliza Weisman <eliza@buoyant.io> | 2020-01-06 14:44:30 -0800 |
---|---|---|
committer | Carl Lerche <me@carllerche.com> | 2020-01-06 14:44:30 -0800 |
commit | 798e86821f6e06fba552bd670c5887ce3b6ff698 (patch) | |
tree | 9dfff891883eab3c27f986002e9389ae1fd65227 /tokio/src/task | |
parent | 0193df3a593cb69d23414109118784de2948024c (diff) |
task: add ways to run a `LocalSet` from within a rt context (#1971)
Currently, the only way to run a `tokio::task::LocalSet` is to call its
`block_on` method with a `&mut Runtime`, like
```rust
let mut rt = tokio::runtime::Runtime::new();
let local = tokio::task::LocalSet::new();
local.block_on(&mut rt, async {
// whatever...
});
```
Unfortunately, this means that `LocalSet` doesn't work with the
`#[tokio::main]` and `#[tokio::test]` macros, since the `main`
function is _already_ inside of a call to `block_on`.
**Solution**
This branch adds a `LocalSet::run` method, which takes a future and
returns a new future that runs that future on the `LocalSet`. This
is analogous to `LocalSet::block_on`, except that it can be called in
an async context.
Additionally, this branch implements `Future` for `LocalSet`. Awaiting
a `LocalSet` will run all spawned local futures until they complete.
This allows code like
```rust
#[tokio::main]
async fn main() {
let local = tokio::task::LocalSet::new();
local.spawn_local(async {
// ...
});
local.spawn_local(async {
// ...
tokio::task::spawn_local(...);
// ...
});
local.await;
}
```
The `LocalSet` docs have been updated to show the usage with
`#[tokio::main]` rather than with manually created runtimes, where
applicable.
Closes #1906
Closes #1908
Fixes #2057
Diffstat (limited to 'tokio/src/task')
-rw-r--r-- | tokio/src/task/local.rs | 612 |
1 files changed, 173 insertions, 439 deletions
diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index d43187a8..ef49eebc 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -22,16 +22,14 @@ cfg_rt_util! { /// For example, the following code will not compile: /// /// ```rust,compile_fail - /// # use tokio::runtime::Runtime; /// use std::rc::Rc; /// - /// // `Rc` does not implement `Send`, and thus may not be sent between - /// // threads safely. - /// let unsend_data = Rc::new("my unsend data..."); + /// #[tokio::main] + /// async fn main() { + /// // `Rc` does not implement `Send`, and thus may not be sent between + /// // threads safely. + /// let unsend_data = Rc::new("my unsend data..."); /// - /// let mut rt = Runtime::new().unwrap(); - /// - /// rt.block_on(async move { /// let unsend_data = unsend_data.clone(); /// // Because the `async` block here moves `unsend_data`, the future is `!Send`. /// // Since `tokio::spawn` requires the spawned future to implement `Send`, this @@ -40,7 +38,7 @@ cfg_rt_util! { /// println!("{}", unsend_data); /// // ... /// }).await.unwrap(); - /// }); + /// } /// ``` /// In order to spawn `!Send` futures, we can use a local task set to /// schedule them on the thread calling [`Runtime::block_on`]. When running @@ -48,26 +46,60 @@ cfg_rt_util! { /// spawn `!Send` futures. For example: /// /// ```rust - /// # use tokio::runtime::Runtime; /// use std::rc::Rc; /// use tokio::task; /// - /// let unsend_data = Rc::new("my unsend data..."); + /// #[tokio::main] + /// async fn main() { + /// let unsend_data = Rc::new("my unsend data..."); /// - /// let mut rt = Runtime::new().unwrap(); - /// // Construct a local task set that can run `!Send` futures. - /// let local = task::LocalSet::new(); + /// // Construct a local task set that can run `!Send` futures. + /// let local = task::LocalSet::new(); /// - /// // Run the local task group. - /// local.block_on(&mut rt, async move { - /// let unsend_data = unsend_data.clone(); - /// // `spawn_local` ensures that the future is spawned on the local - /// // task group. - /// task::spawn_local(async move { - /// println!("{}", unsend_data); + /// // Run the local task set. + /// local.run_until(async move { + /// let unsend_data = unsend_data.clone(); + /// // `spawn_local` ensures that the future is spawned on the local + /// // task set. + /// task::spawn_local(async move { + /// println!("{}", unsend_data); + /// // ... + /// }).await.unwrap(); + /// }).await; + /// } + /// ``` + /// + /// ## Awaiting a `LocalSet` + /// + /// Additionally, a `LocalSet` itself implements `Future`, completing when + /// *all* tasks spawned on the `LocalSet` complete. This can be used to run + /// several futures on a `LocalSet` and drive the whole set until they + /// complete. For example, + /// + /// ```rust + /// use tokio::{task, time}; + /// use std::rc::Rc; + /// + /// #[tokio::main] + /// async fn main() { + /// let unsend_data = Rc::new("world"); + /// let local = task::LocalSet::new(); + /// + /// let unsend_data2 = unsend_data.clone(); + /// local.spawn_local(async move { /// // ... - /// }).await.unwrap(); - /// }); + /// println!("hello {}", unsend_data2) + /// }); + /// + /// local.spawn_local(async move { + /// time::delay_for(time::Duration::from_millis(100)).await; + /// println!("goodbye {}", unsend_data) + /// }); + /// + /// // ... + /// + /// local.await; + /// } /// ``` /// /// [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html @@ -92,6 +124,7 @@ struct Scheduler { } pin_project! { + #[derive(Debug)] struct LocalFuture<F> { scheduler: Rc<Scheduler>, #[pin] @@ -116,23 +149,24 @@ cfg_rt_util! { /// # Examples /// /// ```rust - /// # use tokio::runtime::Runtime; /// use std::rc::Rc; /// use tokio::task; /// - /// let unsend_data = Rc::new("my unsend data..."); + /// #[tokio::main] + /// async fn main() { + /// let unsend_data = Rc::new("my unsend data..."); /// - /// let mut rt = Runtime::new().unwrap(); - /// let local = task::LocalSet::new(); + /// let local = task::LocalSet::new(); /// - /// // Run the local task set. - /// local.block_on(&mut rt, async move { - /// let unsend_data = unsend_data.clone(); - /// task::spawn_local(async move { - /// println!("{}", unsend_data); - /// // ... - /// }).await.unwrap(); - /// }); + /// // Run the local task set. + /// local.run_until(async move { + /// let unsend_data = unsend_data.clone(); + /// task::spawn_local(async move { + /// println!("{}", unsend_data); + /// // ... + /// }).await.unwrap(); + /// }).await; + /// } /// ``` pub fn spawn_local<F>(future: F) -> JoinHandle<F::Output> where @@ -173,34 +207,35 @@ impl LocalSet { /// This task is guaranteed to be run on the current thread. /// /// Unlike the free function [`spawn_local`], this method may be used to - /// spawn_local local tasks when the task set is _not_ running. For example: + /// spawn local tasks when the task set is _not_ running. For example: /// ```rust - /// # use tokio::runtime::Runtime; /// use tokio::task; /// - /// let mut rt = Runtime::new().unwrap(); - /// let local = task::LocalSet::new(); + /// #[tokio::main] + /// async fn main() { + /// let local = task::LocalSet::new(); /// - /// // Spawn a future on the local set. This future will be run when - /// // we call `block_on` to drive the task set. - /// local.spawn_local(async { - /// // ... - /// }); + /// // Spawn a future on the local set. This future will be run when + /// // we call `run_until` to drive the task set. + /// local.spawn_local(async { + /// // ... + /// }); /// - /// // Run the local task set. - /// local.block_on(&mut rt, async move { - /// // ... - /// }); + /// // Run the local task set. + /// local.run_until(async move { + /// // ... + /// }).await; /// - /// // When `block_on` finishes, we can spawn_local _more_ futures, which will - /// // run in subsequent calls to `block_on`. - /// local.spawn_local(async { - /// // ... - /// }); + /// // When `run` finishes, we can spawn _more_ futures, which will + /// // run in subsequent calls to `run_until`. + /// local.spawn_local(async { + /// // ... + /// }); /// - /// local.block_on(&mut rt, async move { - /// // ... - /// }); + /// local.run_until(async move { + /// // ... + /// }).await; + /// } /// ``` /// [`spawn_local`]: fn.spawn_local.html pub fn spawn_local<F>(&self, future: F) -> JoinHandle<F::Output> @@ -283,9 +318,68 @@ impl LocalSet { where F: Future, { + rt.block_on(self.run_until(future)) + } + + /// Run a future to completion on the local set, returning its output. + /// + /// This returns a future that runs the given future with a local set, + /// allowing it to call [`spawn_local`] to spawn additional `!Send` futures. + /// Any local futures spawned on the local set will be driven in the + /// background until the future passed to `run_until` completes. When the future + /// passed to `run` finishes, any local futures which have not completed + /// will remain on the local set, and will be driven on subsequent calls to + /// `run_until` or when [awaiting the local set] itself. + /// + /// # Examples + /// + /// ```rust + /// use tokio::task; + /// + /// #[tokio::main] + /// async fn main() { + /// task::LocalSet::new().run_until(async { + /// task::spawn_local(async move { + /// // ... + /// }).await.unwrap(); + /// // ... + /// }).await; + /// } + /// ``` + /// + /// [`spawn_local`]: fn.spawn_local.html + /// [awaiting the local set]: #awaiting-a-localset + pub async fn run_until<F>(&self, future: F) -> F::Output + where + F: Future, + { let scheduler = self.scheduler.clone(); - self.scheduler - .with(move || rt.block_on(LocalFuture { scheduler, future })) + let future = LocalFuture { scheduler, future }; + future.await + } +} + +impl Future for LocalSet { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let scheduler = self.as_ref().scheduler.clone(); + scheduler.waker.register_by_ref(cx.waker()); + + if scheduler.with(|| scheduler.tick()) { + // If `tick` returns true, we need to notify the local future again: + // there are still tasks remaining in the run queue. + cx.waker().wake_by_ref(); + Poll::Pending + } else if scheduler.is_empty() { + // If the scheduler has no remaining futures, we're done! + Poll::Ready(()) + } else { + // There are still futures in the local set, but we've polled all the + // futures in the run queue. Therefore, we can just return Pending + // since the remaining futures will be woken from somewhere else. + Poll::Pending + } } } @@ -295,6 +389,8 @@ impl Default for LocalSet { } } +// === impl LocalFuture === + impl<F: Future> Future for LocalFuture<F> { type Output = F::Output; @@ -303,18 +399,19 @@ impl<F: Future> Future for LocalFuture<F> { let scheduler = this.scheduler; let mut future = this.future; scheduler.waker.register_by_ref(cx.waker()); + scheduler.with(|| { + if let Poll::Ready(output) = future.as_mut().poll(cx) { + return Poll::Ready(output); + } - if let Poll::Ready(output) = future.as_mut().poll(cx) { - return Poll::Ready(output); - } - - if scheduler.tick() { - // If `tick` returns true, we need to notify the local future again: - // there are still tasks remaining in the run queue. - cx.waker().wake_by_ref(); - } + if scheduler.tick() { + // If `tick` returns true, we need to notify the local future again: + // there are still tasks remaining in the run queue. + cx.waker().wake_by_ref(); + } - Poll::Pending + Poll::Pending + }) } } @@ -424,6 +521,15 @@ impl Scheduler { true } + + fn is_empty(&self) -> bool { + unsafe { + // safety: this method may not be called from threads other than the + // thread that owns the `Queues`. since `Scheduler` is not `Send` or + // `Sync`, that shouldn't happen. + !self.queues.has_tasks_remaining() + } + } } impl Drop for Scheduler { @@ -450,375 +556,3 @@ impl Drop for Scheduler { } } } - -#[cfg(all(test, not(loom)))] -mod tests { - use super::*; - use crate::{ - runtime, - sync::{mpsc, oneshot}, - task, time, - }; - use std::time::Duration; - - #[test] - fn local_current_thread() { - let mut rt = runtime::Builder::new().basic_scheduler().build().unwrap(); - LocalSet::new().block_on(&mut rt, async { - spawn_local(async {}).await.unwrap(); - }); - } - - #[test] - fn local_threadpool() { - thread_local! { - static ON_RT_THREAD: Cell<bool> = Cell::new(false); - } - - ON_RT_THREAD.with(|cell| cell.set(true)); - - let mut rt = runtime::Runtime::new().unwrap(); - LocalSet::new().block_on(&mut rt, async { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - spawn_local(async { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - }) - .await - .unwrap(); - }); - } - - #[test] - fn local_threadpool_timer() { - // This test ensures that runtime services like the timer are properly - // set for the local task set. - thread_local! { - static ON_RT_THREAD: Cell<bool> = Cell::new(false); - } - - ON_RT_THREAD.with(|cell| cell.set(true)); - - let mut rt = runtime::Builder::new() - .threaded_scheduler() - .enable_all() - .build() - .unwrap(); - LocalSet::new().block_on(&mut rt, async { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - let join = spawn_local(async move { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - crate::time::delay_for(Duration::from_millis(10)).await; - assert!(ON_RT_THREAD.with(|cell| cell.get())); - }); - join.await.unwrap(); - }); - } - - #[test] - // This will panic, since the thread that calls `block_on` cannot use - // in-place blocking inside of `block_on`. - #[should_panic] - fn local_threadpool_blocking_in_place() { - thread_local! { - static ON_RT_THREAD: Cell<bool> = Cell::new(false); - } - - ON_RT_THREAD.with(|cell| cell.set(true)); - - let mut rt = runtime::Builder::new() - .threaded_scheduler() - .enable_all() - .build() - .unwrap(); - LocalSet::new().block_on(&mut rt, async { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - let join = spawn_local(async move { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - task::block_in_place(|| {}); - assert!(ON_RT_THREAD.with(|cell| cell.get())); - }); - join.await.unwrap(); - }); - } - - #[test] - fn local_threadpool_blocking_run() { - thread_local! { - static ON_RT_THREAD: Cell<bool> = Cell::new(false); - } - - ON_RT_THREAD.with(|cell| cell.set(true)); - - let mut rt = runtime::Builder::new() - .threaded_scheduler() - .enable_all() - .build() - .unwrap(); - LocalSet::new().block_on(&mut rt, async { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - let join = spawn_local(async move { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - task::spawn_blocking(|| { - assert!( - !ON_RT_THREAD.with(|cell| cell.get()), - "blocking must not run on the local task set's thread" - ); - }) - .await - .unwrap(); - assert!(ON_RT_THREAD.with(|cell| cell.get())); - }); - join.await.unwrap(); - }); - } - - #[test] - fn all_spawns_are_local() { - use futures::future; - thread_local! { - static ON_RT_THREAD: Cell<bool> = Cell::new(false); - } - - ON_RT_THREAD.with(|cell| cell.set(true)); - - let mut rt = runtime::Builder::new() - .threaded_scheduler() - .build() - .unwrap(); - LocalSet::new().block_on(&mut rt, async { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - let handles = (0..128) - .map(|_| { - spawn_local(async { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - }) - }) - .collect::<Vec<_>>(); - for joined in future::join_all(handles).await { - joined.unwrap(); - } - }) - } - - #[test] - fn nested_spawn_is_local() { - thread_local! { - static ON_RT_THREAD: Cell<bool> = Cell::new(false); - } - - ON_RT_THREAD.with(|cell| cell.set(true)); - - let mut rt = runtime::Builder::new() - .threaded_scheduler() - .build() - .unwrap(); - LocalSet::new().block_on(&mut rt, async { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - spawn_local(async { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - spawn_local(async { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - spawn_local(async { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - spawn_local(async { - assert!(ON_RT_THREAD.with(|cell| cell.get())); - }) - .await - .unwrap(); - }) - .await - .unwrap(); - }) - .await - .unwrap(); - }) - .await - .unwrap(); - }) - } - #[test] - fn join_local_future_elsewhere() { - thread_local! { - static ON_RT_THREAD: Cell<bool> = Cell::new(false); - } - - ON_RT_THREAD.with(|cell| cell.set(true)); - - let mut rt = runtime::Builder::new() - .threaded_scheduler() - .build() - .unwrap(); - let local = LocalSet::new(); - local.block_on(&mut rt, async move { - let (tx, rx) = crate::sync::oneshot::channel(); - let join = spawn_local(async move { - println!("hello world running..."); - assert!( - ON_RT_THREAD.with(|cell| cell.get()), - "local task must run on local thread, no matter where it is awaited" - ); - rx.await.unwrap(); - - println!("hello world task done"); - "hello world" - }); - let join2 = task::spawn(async move { - assert!( - !ON_RT_THREAD.with(|cell| cell.get()), - "spawned task should be on a worker" - ); - - tx.send(()).expect("task shouldn't have ended yet"); - println!("waking up hello world..."); - - join.await.expect("task should complete successfully"); - - println!("hello world task joined"); - }); - join2.await.unwrap() - }); - } - #[test] - fn drop_cancels_tasks() { - // This test reproduces issue #1842 - let mut rt = runtime::Builder::new() - .enable_time() - .basic_scheduler() - .build() - .unwrap(); - - let (started_tx, started_rx) = oneshot::channel(); - - let local = LocalSet::new(); - local.spawn_local(async move { - started_tx.send(()).unwrap(); - loop { - time::delay_for(Duration::from_secs(3600)).await; - } - }); - - local.block_on(&mut rt, async { - started_rx.await.unwrap(); - }); - drop(local); - drop(rt); - } - - #[test] - fn drop_cancels_remote_tasks() { - // This test reproduces issue #1885. - use std::sync::mpsc::RecvTimeoutError; - - let (done_tx, done_rx) = std::sync::mpsc::channel(); - let thread = std::thread::spawn(move || { - let (tx, mut rx) = crate::sync::mpsc::channel::<()>(1024); - - let mut rt = runtime::Builder::new() - .enable_time() - .basic_scheduler() - .build() - .expect("building runtime should succeed"); - - let local = LocalSet::new(); - local.spawn_local(async move { while let Some(_) = rx.recv().await {} }); - local.block_on(&mut rt, async { - crate::time::delay_for(Duration::from_millis(1)).await; - }); - - drop(tx); - - // This enters an infinite loop if the remote notified tasks are not - // properly cancelled. - drop(local); - - // Send a message on the channel so that the test thread can - // determine if we have entered an infinite loop: - done_tx.send(()).unwrap(); - }); - - // Since the failure mode of this test is an infinite loop, rather than - // something we can easily make assertions about, we'll run it in a - // thread. When the test thread finishes, it will send a message on a - // channel to this thread. We'll wait for that message with a fairly - // generous timeout, and if we don't recieve it, we assume the test - // thread has hung. - // - // Note that it should definitely complete in under a minute, but just - // in case CI is slow, we'll give it a long timeout. - match done_rx.recv_timeout(Duration::from_secs(60)) { - Err(RecvTimeoutError::Timeout) => panic!( - "test did not complete within 60 seconds, \ - we have (probably) entered an infinite loop!" - ), - // Did the test thread panic? We'll find out for sure when we `join` - // with it. - Err(RecvTimeoutError::Disconnected) => { - println!("done_rx dropped, did the test thread panic?"); - } - // Test completed successfully! - Ok(()) => {} - } - - thread.join().expect("test thread should not panic!") - } - - #[test] - fn local_tasks_are_polled_after_tick() { - // Reproduces issues #1899 and #1900 - use std::sync::atomic::{AtomicUsize, Ordering::SeqCst}; - - static RX1: AtomicUsize = AtomicUsize::new(0); - static RX2: AtomicUsize = AtomicUsize::new(0); - static EXPECTED: usize = 500; - - let (tx, mut rx) = mpsc::unbounded_channel(); - - let mut rt = runtime::Builder::new() - .basic_scheduler() - .enable_all() - .build() - .unwrap(); - - let local = LocalSet::new(); - - local.block_on(&mut rt, async { - let task2 = task::spawn(async move { - // Wait a bit - time::delay_for(Duration::from_millis(100)).await; - - let mut oneshots = Vec::with_capacity(EXPECTED); - - // Send values - for _ in 0..EXPECTED { - let (oneshot_tx, oneshot_rx) = oneshot::channel(); - oneshots.push(oneshot_tx); - tx.send(oneshot_rx).unwrap(); - } - - time::delay_for(Duration::from_millis(100)).await; - - for tx in oneshots.drain(..) { - tx.send(()).unwrap(); - } - - time::delay_for(Duration::from_millis(300)).await; - let rx1 = RX1.load(SeqCst); - let rx2 = RX2.load(SeqCst); - println!("EXPECT = {}; RX1 = {}; RX2 = {}", EXPECTED, rx1, rx2); - assert_eq!(EXPECTED, rx1); - assert_eq!(EXPECTED, rx2); - }); - - while let Some(oneshot) = rx.recv().await { - RX1.fetch_add(1, SeqCst); - - task::spawn_local(async move { - oneshot.await.unwrap(); - RX2.fetch_add(1, SeqCst); - }); - } - - task2.await.unwrap(); - }); - } -} |