summaryrefslogtreecommitdiffstats
path: root/tokio
diff options
context:
space:
mode:
authorbdonlan <bdonlan@gmail.com>2020-12-08 16:42:43 -0800
committerGitHub <noreply@github.com>2020-12-08 16:42:43 -0800
commit9706ca92a8deb69d6e29265f21424042fea966c5 (patch)
treecd77e2148b7cdf03d0fcb38e8e27cf3f7eed1ed9 /tokio
parentfc7a4b3c6e765d6d2b4ea97266cefbf466d52dc9 (diff)
time: Fix race condition in timer drop (#3229)
Dropping a timer on the millisecond that it was scheduled for, when it was on the pending list, could result in a panic previously, as we did not record the pending-list state in cached_when. Hopefully fixes: ZcashFoundation/zebra#1452
Diffstat (limited to 'tokio')
-rw-r--r--tokio/src/time/driver/entry.rs24
-rw-r--r--tokio/src/time/driver/wheel/mod.rs4
-rw-r--r--tokio/tests/time_sleep.rs72
3 files changed, 91 insertions, 9 deletions
diff --git a/tokio/src/time/driver/entry.rs b/tokio/src/time/driver/entry.rs
index e0926797..87ba0c17 100644
--- a/tokio/src/time/driver/entry.rs
+++ b/tokio/src/time/driver/entry.rs
@@ -437,6 +437,17 @@ impl TimerShared {
true_when
}
+ /// Sets the cached time-of-expiration value.
+ ///
+ /// SAFETY: Must be called with the driver lock held, and when this entry is
+ /// not in any timer wheel lists.
+ unsafe fn set_cached_when(&self, when: u64) {
+ self.driver_state
+ .0
+ .cached_when
+ .store(when, Ordering::Relaxed);
+ }
+
/// Returns the true time-of-expiration value, with relaxed memory ordering.
pub(super) fn true_when(&self) -> u64 {
self.state.when().expect("Timer already fired")
@@ -643,14 +654,13 @@ impl TimerHandle {
/// After returning Ok, the entry must be added to the pending list.
pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
match self.inner.as_ref().state.mark_pending(not_after) {
- Ok(()) => Ok(()),
+ Ok(()) => {
+ // mark this as being on the pending queue in cached_when
+ self.inner.as_ref().set_cached_when(u64::max_value());
+ Ok(())
+ }
Err(tick) => {
- self.inner
- .as_ref()
- .driver_state
- .0
- .cached_when
- .store(tick, Ordering::Relaxed);
+ self.inner.as_ref().set_cached_when(tick);
Err(tick)
}
}
diff --git a/tokio/src/time/driver/wheel/mod.rs b/tokio/src/time/driver/wheel/mod.rs
index e9df87af..164cac46 100644
--- a/tokio/src/time/driver/wheel/mod.rs
+++ b/tokio/src/time/driver/wheel/mod.rs
@@ -118,10 +118,10 @@ impl Wheel {
/// Remove `item` from the timing wheel.
pub(crate) unsafe fn remove(&mut self, item: NonNull<TimerShared>) {
unsafe {
- if !item.as_ref().might_be_registered() {
+ let when = item.as_ref().cached_when();
+ if when == u64::max_value() {
self.pending.remove(item);
} else {
- let when = item.as_ref().cached_when();
let level = self.level_for(when);
self.levels[level].remove_entry(item);
diff --git a/tokio/tests/time_sleep.rs b/tokio/tests/time_sleep.rs
index d110ec27..20e2b1c6 100644
--- a/tokio/tests/time_sleep.rs
+++ b/tokio/tests/time_sleep.rs
@@ -308,3 +308,75 @@ async fn no_out_of_bounds_close_to_max() {
fn ms(n: u64) -> Duration {
Duration::from_millis(n)
}
+
+#[tokio::test]
+async fn drop_after_reschedule_at_new_scheduled_time() {
+ use futures::poll;
+
+ tokio::time::pause();
+
+ let start = tokio::time::Instant::now();
+
+ let mut a = tokio::time::sleep(Duration::from_millis(5));
+ let mut b = tokio::time::sleep(Duration::from_millis(5));
+ let mut c = tokio::time::sleep(Duration::from_millis(10));
+
+ let _ = poll!(&mut a);
+ let _ = poll!(&mut b);
+ let _ = poll!(&mut c);
+
+ b.reset(start + Duration::from_millis(10));
+ a.await;
+
+ drop(b);
+}
+
+#[tokio::test]
+async fn drop_from_wake() {
+ use std::future::Future;
+ use std::sync::atomic::{AtomicBool, Ordering};
+ use std::sync::{Arc, Mutex};
+ use std::task::Context;
+
+ let panicked = Arc::new(AtomicBool::new(false));
+ let list: Arc<Mutex<Vec<tokio::time::Sleep>>> = Arc::new(Mutex::new(Vec::new()));
+
+ let arc_wake = Arc::new(DropWaker(panicked.clone(), list.clone()));
+ let arc_wake = futures::task::waker(arc_wake);
+
+ tokio::time::pause();
+
+ let mut lock = list.lock().unwrap();
+
+ for _ in 0..100 {
+ let mut timer = tokio::time::sleep(Duration::from_millis(10));
+
+ let _ = std::pin::Pin::new(&mut timer).poll(&mut Context::from_waker(&arc_wake));
+
+ lock.push(timer);
+ }
+
+ drop(lock);
+
+ tokio::time::sleep(Duration::from_millis(11)).await;
+
+ assert!(
+ !panicked.load(Ordering::SeqCst),
+ "paniced when dropping timers"
+ );
+
+ #[derive(Clone)]
+ struct DropWaker(Arc<AtomicBool>, Arc<Mutex<Vec<tokio::time::Sleep>>>);
+
+ impl futures::task::ArcWake for DropWaker {
+ fn wake_by_ref(arc_self: &Arc<Self>) {
+ let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
+ *arc_self.1.lock().expect("panic in lock") = Vec::new()
+ }));
+
+ if result.is_err() {
+ arc_self.0.store(true, Ordering::SeqCst);
+ }
+ }
+ }
+}