summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorCarl Lerche <me@carllerche.com>2020-01-29 12:00:40 -0800
committerGitHub <noreply@github.com>2020-01-29 12:00:40 -0800
commit9d6b99494b72e79b4afba5073a9ebef5bbbeca8a (patch)
tree8804efeaec910a7881c63ff7723340d8993f23de
parent560d0fa548314e223601ed83429daf237d72afbd (diff)
rt: add `Runtime::shutdown_timeout` (#2186)
Provides an API for forcing a runtime to shutdown even if there are still running tasks.
-rw-r--r--tokio/src/park/thread.rs10
-rw-r--r--tokio/src/runtime/blocking/mod.rs4
-rw-r--r--tokio/src/runtime/blocking/pool.rs19
-rw-r--r--tokio/src/runtime/blocking/shutdown.rs14
-rw-r--r--tokio/src/runtime/enter.rs39
-rw-r--r--tokio/src/runtime/mod.rs43
-rw-r--r--tokio/tests/rt_common.rs17
7 files changed, 132 insertions, 14 deletions
diff --git a/tokio/src/park/thread.rs b/tokio/src/park/thread.rs
index 853d4d8a..12ef9717 100644
--- a/tokio/src/park/thread.rs
+++ b/tokio/src/park/thread.rs
@@ -10,13 +10,7 @@ pub(crate) struct ParkThread {
inner: Arc<Inner>,
}
-/// Error returned by `ParkThread`
-///
-/// This currently is never returned, but might at some point in the future.
-#[derive(Debug)]
-pub(crate) struct ParkError {
- _p: (),
-}
+pub(crate) type ParkError = ();
/// Unblocks a thread that was blocked by `ParkThread`.
#[derive(Clone, Debug)]
@@ -240,7 +234,7 @@ cfg_blocking_impl! {
F: FnOnce(&ParkThread) -> R,
{
CURRENT_PARKER.try_with(|inner| f(inner))
- .map_err(|_| ParkError { _p: () })
+ .map_err(|_| ())
}
}
diff --git a/tokio/src/runtime/blocking/mod.rs b/tokio/src/runtime/blocking/mod.rs
index be56e8f8..ff400b33 100644
--- a/tokio/src/runtime/blocking/mod.rs
+++ b/tokio/src/runtime/blocking/mod.rs
@@ -21,6 +21,7 @@ cfg_blocking_impl! {
cfg_not_blocking_impl! {
use crate::runtime::Builder;
+ use std::time::Duration;
#[derive(Debug, Clone)]
pub(crate) struct BlockingPool {}
@@ -35,5 +36,8 @@ cfg_not_blocking_impl! {
pub(crate) fn spawner(&self) -> &BlockingPool {
self
}
+
+ pub(crate) fn shutdown(&mut self, _duration: Option<Duration>) {
+ }
}
}
diff --git a/tokio/src/runtime/blocking/pool.rs b/tokio/src/runtime/blocking/pool.rs
index 2a618ff5..1784312d 100644
--- a/tokio/src/runtime/blocking/pool.rs
+++ b/tokio/src/runtime/blocking/pool.rs
@@ -101,19 +101,30 @@ impl BlockingPool {
pub(crate) fn spawner(&self) -> &Spawner {
&self.spawner
}
-}
-impl Drop for BlockingPool {
- fn drop(&mut self) {
+ pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) {
let mut shared = self.spawner.inner.shared.lock().unwrap();
+ // The function can be called multiple times. First, by explicitly
+ // calling `shutdown` then by the drop handler calling `shutdown`. This
+ // prevents shutting down twice.
+ if shared.shutdown {
+ return;
+ }
+
shared.shutdown = true;
shared.shutdown_tx = None;
self.spawner.inner.condvar.notify_all();
drop(shared);
- self.shutdown_rx.wait();
+ self.shutdown_rx.wait(timeout);
+ }
+}
+
+impl Drop for BlockingPool {
+ fn drop(&mut self) {
+ self.shutdown(None);
}
}
diff --git a/tokio/src/runtime/blocking/shutdown.rs b/tokio/src/runtime/blocking/shutdown.rs
index a7b4fc5e..5ee8af0f 100644
--- a/tokio/src/runtime/blocking/shutdown.rs
+++ b/tokio/src/runtime/blocking/shutdown.rs
@@ -6,6 +6,8 @@
use crate::loom::sync::Arc;
use crate::sync::oneshot;
+use std::time::Duration;
+
#[derive(Debug, Clone)]
pub(super) struct Sender {
tx: Arc<oneshot::Sender<()>>,
@@ -26,7 +28,11 @@ pub(super) fn channel() -> (Sender, Receiver) {
impl Receiver {
/// Blocks the current thread until all `Sender` handles drop.
- pub(crate) fn wait(&mut self) {
+ ///
+ /// If `timeout` is `Some`, the thread is blocked for **at most** `timeout`
+ /// duration. If `timeout` is `None`, then the thread is blocked until the
+ /// shutdown signal is received.
+ pub(crate) fn wait(&mut self, timeout: Option<Duration>) {
use crate::runtime::enter::{enter, try_enter};
let mut e = if std::thread::panicking() {
@@ -43,6 +49,10 @@ impl Receiver {
// If blocking fails to wait, this indicates a problem parking the
// current thread (usually, shutting down a runtime stored in a
// thread-local).
- let _ = e.block_on(&mut self.rx);
+ if let Some(timeout) = timeout {
+ let _ = e.block_on_timeout(&mut self.rx, timeout);
+ } else {
+ let _ = e.block_on(&mut self.rx);
+ }
}
}
diff --git a/tokio/src/runtime/enter.rs b/tokio/src/runtime/enter.rs
index bada8e7b..64648b71 100644
--- a/tokio/src/runtime/enter.rs
+++ b/tokio/src/runtime/enter.rs
@@ -75,6 +75,7 @@ pub(crate) fn exit<F: FnOnce() -> R, R>(f: F) -> R {
cfg_blocking_impl! {
use crate::park::ParkError;
+ use std::time::Duration;
impl Enter {
/// Blocks the thread on the specified future, returning the value with
@@ -104,6 +105,44 @@ cfg_blocking_impl! {
park.park()?;
}
}
+
+ /// Blocks the thread on the specified future for **at most** `timeout`
+ ///
+ /// If the future completes before `timeout`, the result is returned. If
+ /// `timeout` elapses, then `Err` is returned.
+ pub(crate) fn block_on_timeout<F>(&mut self, mut f: F, timeout: Duration) -> Result<F::Output, ParkError>
+ where
+ F: std::future::Future,
+ {
+ use crate::park::{CachedParkThread, Park};
+ use std::pin::Pin;
+ use std::task::Context;
+ use std::task::Poll::Ready;
+ use std::time::Instant;
+
+ let mut park = CachedParkThread::new();
+ let waker = park.get_unpark()?.into_waker();
+ let mut cx = Context::from_waker(&waker);
+
+ // `block_on` takes ownership of `f`. Once it is pinned here, the original `f` binding can
+ // no longer be accessed, making the pinning safe.
+ let mut f = unsafe { Pin::new_unchecked(&mut f) };
+ let when = Instant::now() + timeout;
+
+ loop {
+ if let Ready(v) = f.as_mut().poll(&mut cx) {
+ return Ok(v);
+ }
+
+ let now = Instant::now();
+
+ if now >= when {
+ return Err(());
+ }
+
+ park.park_timeout(when - now)?;
+ }
+ }
}
}
diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs
index 18e43e87..3c56e138 100644
--- a/tokio/src/runtime/mod.rs
+++ b/tokio/src/runtime/mod.rs
@@ -235,6 +235,7 @@ cfg_rt_core! {
}
use std::future::Future;
+use std::time::Duration;
/// The Tokio runtime.
///
@@ -441,4 +442,46 @@ impl Runtime {
pub fn handle(&self) -> &Handle {
&self.handle
}
+
+ /// Shutdown the runtime, waiting for at most `duration` for all spawned
+ /// task to shutdown.
+ ///
+ /// Usually, dropping a `Runtime` handle is sufficient as tasks are able to
+ /// shutdown in a timely fashion. However, dropping a `Runtime` will wait
+ /// indefinitely for all tasks to terminate, and there are cases where a long
+ /// blocking task has been spawned which can block dropping `Runtime`.
+ ///
+ /// In this case, calling `shutdown_timeout` with an explicit wait timeout
+ /// can work. The `shutdown_timeout` will signal all tasks to shutdown and
+ /// will wait for at most `duration` for all spawned tasks to terminate. If
+ /// `timeout` elapses before all tasks are dropped, the function returns and
+ /// outstanding tasks are potentially leaked.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::runtime::Runtime;
+ /// use tokio::task;
+ ///
+ /// use std::thread;
+ /// use std::time::Duration;
+ ///
+ /// fn main() {
+ /// let mut runtime = Runtime::new().unwrap();
+ ///
+ /// runtime.block_on(async move {
+ /// task::spawn_blocking(move || {
+ /// thread::sleep(Duration::from_secs(10_000));
+ /// });
+ /// });
+ ///
+ /// runtime.shutdown_timeout(Duration::from_millis(100));
+ /// }
+ /// ```
+ pub fn shutdown_timeout(self, duration: Duration) {
+ let Runtime {
+ mut blocking_pool, ..
+ } = self;
+ blocking_pool.shutdown(Some(duration));
+ }
}
diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs
index 31edd10a..64dd3680 100644
--- a/tokio/tests/rt_common.rs
+++ b/tokio/tests/rt_common.rs
@@ -710,6 +710,23 @@ rt_test! {
}
#[test]
+ fn shutdown_timeout() {
+ let (tx, rx) = oneshot::channel();
+ let mut runtime = rt();
+
+ runtime.block_on(async move {
+ task::spawn_blocking(move || {
+ tx.send(()).unwrap();
+ thread::sleep(Duration::from_secs(10_000));
+ });
+
+ rx.await.unwrap();
+ });
+
+ runtime.shutdown_timeout(Duration::from_millis(100));
+ }
+
+ #[test]
fn runtime_in_thread_local() {
use std::cell::RefCell;
use std::thread;