From d0ebb4154748166a4ba07baa4b424a1c45efd219 Mon Sep 17 00:00:00 2001 From: Zahari Dichev Date: Mon, 16 Nov 2020 22:49:35 +0200 Subject: sync: add `Notify::notify_waiters` (#3098) This PR makes `Notify::notify_waiters` public. The method already exists, but it changes the way `notify_waiters`, is used. Previously in order for the consumer to register interest, in a notification triggered by `notify_waiters`, the `Notified` future had to be polled. This introduced friction when using the api as the future had to be pinned before polled. This change introduces a counter that tracks how many times `notified_waiters` has been called. Upon creation of the future the number of times is loaded. When first polled the future compares this number with the count state of the `Notify` type. This avoids the need for registering the waiter upfront. Fixes: #3066 --- tokio/src/sync/mpsc/chan.rs | 19 +--- tokio/src/sync/notify.rs | 207 ++++++++++++++++++++++++++++-------- tokio/src/sync/tests/loom_notify.rs | 21 ++++ tokio/src/sync/watch.rs | 17 +-- tokio/tests/sync_notify.rs | 34 ++++++ 5 files changed, 218 insertions(+), 80 deletions(-) diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index c78fb501..a40f5c3d 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -148,25 +148,10 @@ impl Tx { } pub(crate) async fn closed(&self) { - use std::future::Future; - use std::pin::Pin; - use std::task::Poll; - // In order to avoid a race condition, we first request a notification, - // **then** check the current value's version. If a new version exists, - // the notification request is dropped. Requesting the notification - // requires polling the future once. + // **then** check whether the semaphore is closed. If the semaphore is + // closed the notification request is dropped. let notified = self.inner.notify_rx_closed.notified(); - pin!(notified); - - // Polling the future once is guaranteed to return `Pending` as `watch` - // only notifies using `notify_waiters`. - crate::future::poll_fn(|cx| { - let res = Pin::new(&mut notified).poll(cx); - assert!(!res.is_ready()); - Poll::Ready(()) - }) - .await; if self.inner.semaphore.is_closed() { return; diff --git a/tokio/src/sync/notify.rs b/tokio/src/sync/notify.rs index 922f1095..f39f92f8 100644 --- a/tokio/src/sync/notify.rs +++ b/tokio/src/sync/notify.rs @@ -5,7 +5,7 @@ // triggers this warning but it is safe to ignore in this case. #![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] -use crate::loom::sync::atomic::AtomicU8; +use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Mutex; use crate::util::linked_list::{self, LinkedList}; @@ -109,7 +109,11 @@ type WaitList = LinkedList::Target>; /// [`Semaphore`]: crate::sync::Semaphore #[derive(Debug)] pub struct Notify { - state: AtomicU8, + // This uses 2 bits to store one of `EMPTY`, + // `WAITING` or `NOTIFIED`. The rest of the bits + // are used to store the number of times `notify_waiters` + // was called. + state: AtomicUsize, waiters: Mutex, } @@ -154,19 +158,39 @@ unsafe impl<'a> Sync for Notified<'a> {} #[derive(Debug)] enum State { - Init, + Init(usize), Waiting, Done, } +const NOTIFY_WAITERS_SHIFT: usize = 2; +const STATE_MASK: usize = (1 << NOTIFY_WAITERS_SHIFT) - 1; +const NOTIFY_WAITERS_CALLS_MASK: usize = !STATE_MASK; + /// Initial "idle" state -const EMPTY: u8 = 0; +const EMPTY: usize = 0; /// One or more threads are currently waiting to be notified. -const WAITING: u8 = 1; +const WAITING: usize = 1; /// Pending notification -const NOTIFIED: u8 = 2; +const NOTIFIED: usize = 2; + +fn set_state(data: usize, state: usize) -> usize { + (data & NOTIFY_WAITERS_CALLS_MASK) | (state & STATE_MASK) +} + +fn get_state(data: usize) -> usize { + data & STATE_MASK +} + +fn get_num_notify_waiters_calls(data: usize) -> usize { + (data & NOTIFY_WAITERS_CALLS_MASK) >> NOTIFY_WAITERS_SHIFT +} + +fn inc_num_notify_waiters_calls(data: usize) -> usize { + data + (1 << NOTIFY_WAITERS_SHIFT) +} impl Notify { /// Create a new `Notify`, initialized without a permit. @@ -180,7 +204,7 @@ impl Notify { /// ``` pub fn new() -> Notify { Notify { - state: AtomicU8::new(0), + state: AtomicUsize::new(0), waiters: Mutex::new(LinkedList::new()), } } @@ -198,7 +222,7 @@ impl Notify { #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] pub const fn const_new() -> Notify { Notify { - state: AtomicU8::new(0), + state: AtomicUsize::new(0), waiters: Mutex::const_new(LinkedList::new()), } } @@ -239,9 +263,12 @@ impl Notify { /// } /// ``` pub fn notified(&self) -> Notified<'_> { + // we load the number of times notify_waiters + // was called and store that in our initial state + let state = self.state.load(SeqCst); Notified { notify: self, - state: State::Init, + state: State::Init(state >> NOTIFY_WAITERS_SHIFT), waiter: UnsafeCell::new(Waiter { pointers: linked_list::Pointers::new(), waker: None, @@ -290,11 +317,12 @@ impl Notify { let mut curr = self.state.load(SeqCst); // If the state is `EMPTY`, transition to `NOTIFIED` and return. - while let EMPTY | NOTIFIED = curr { + while let EMPTY | NOTIFIED = get_state(curr) { // The compare-exchange from `NOTIFIED` -> `NOTIFIED` is intended. A // happens-before synchronization must happen between this atomic // operation and a task calling `notified().await`. - let res = self.state.compare_exchange(curr, NOTIFIED, SeqCst, SeqCst); + let new = set_state(curr, NOTIFIED); + let res = self.state.compare_exchange(curr, new, SeqCst, SeqCst); match res { // No waiters, no further work to do @@ -319,7 +347,43 @@ impl Notify { } /// Notifies all waiting tasks - pub(crate) fn notify_waiters(&self) { + /// + /// If a task is currently waiting, that task is notified. Unlike with + /// `notify()`, no permit is stored to be used by the next call to + /// [`notified().await`]. The purpose of this method is to notify all + /// already registered waiters. Registering for notification is done by + /// acquiring an instance of the `Notified` future via calling `notified()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let notify = Arc::new(Notify::new()); + /// let notify2 = notify.clone(); + /// + /// let notified1 = notify.notified(); + /// let notified2 = notify.notified(); + /// + /// let handle = tokio::spawn(async move { + /// println!("sending notifications"); + /// notify2.notify_waiters(); + /// }); + /// + /// notified1.await; + /// notified2.await; + /// println!("received notifications"); + /// } + /// ``` + pub fn notify_waiters(&self) { + const NUM_WAKERS: usize = 32; + + let mut wakers: [Option; NUM_WAKERS] = Default::default(); + let mut curr_waker = 0; + // There are waiters, the lock must be acquired to notify. let mut waiters = self.waiters.lock(); @@ -327,34 +391,64 @@ impl Notify { // transition out of WAITING while the lock is held. let curr = self.state.load(SeqCst); - if let EMPTY | NOTIFIED = curr { + if let EMPTY | NOTIFIED = get_state(curr) { // There are no waiting tasks. In this case, no synchronization is // established between `notify` and `notified().await`. + // All we need to do is increment the number of times this + // method was called. + self.state.store(inc_num_notify_waiters_calls(curr), SeqCst); return; } // At this point, it is guaranteed that the state will not // concurrently change, as holding the lock is required to // transition **out** of `WAITING`. - // - // Get pending waiters - while let Some(mut waiter) = waiters.pop_back() { - // Safety: `waiters` lock is still held. - let waiter = unsafe { waiter.as_mut() }; + 'outer: loop { + while curr_waker < NUM_WAKERS { + match waiters.pop_back() { + Some(mut waiter) => { + // Safety: `waiters` lock is still held. + let waiter = unsafe { waiter.as_mut() }; - assert!(waiter.notified.is_none()); + assert!(waiter.notified.is_none()); - waiter.notified = Some(NotificationType::AllWaiters); + waiter.notified = Some(NotificationType::AllWaiters); - if let Some(waker) = waiter.waker.take() { - waker.wake(); + if let Some(waker) = waiter.waker.take() { + wakers[curr_waker] = Some(waker); + curr_waker += 1; + } + } + None => { + break 'outer; + } + } } + + drop(waiters); + + for waker in wakers.iter_mut().take(curr_waker) { + waker.take().unwrap().wake(); + } + + curr_waker = 0; + + // Acquire the lock again. + waiters = self.waiters.lock(); } - // All waiters have been notified, the state must be transitioned to + // All waiters will be notified, the state must be transitioned to // `EMPTY`. As transitioning **from** `WAITING` requires the lock to be // held, a `store` is sufficient. - self.state.store(EMPTY, SeqCst); + let new = set_state(inc_num_notify_waiters_calls(curr), EMPTY); + self.state.store(new, SeqCst); + + // Release the lock before notifying + drop(waiters); + + for waker in wakers.iter_mut().take(curr_waker) { + waker.take().unwrap().wake(); + } } } @@ -364,17 +458,18 @@ impl Default for Notify { } } -fn notify_locked(waiters: &mut WaitList, state: &AtomicU8, curr: u8) -> Option { +fn notify_locked(waiters: &mut WaitList, state: &AtomicUsize, curr: usize) -> Option { loop { - match curr { + match get_state(curr) { EMPTY | NOTIFIED => { - let res = state.compare_exchange(curr, NOTIFIED, SeqCst, SeqCst); + let res = state.compare_exchange(curr, set_state(curr, NOTIFIED), SeqCst, SeqCst); match res { Ok(_) => return None, Err(actual) => { - assert!(actual == EMPTY || actual == NOTIFIED); - state.store(NOTIFIED, SeqCst); + let actual_state = get_state(actual); + assert!(actual_state == EMPTY || actual_state == NOTIFIED); + state.store(set_state(actual, NOTIFIED), SeqCst); return None; } } @@ -400,7 +495,7 @@ fn notify_locked(waiters: &mut WaitList, state: &AtomicU8, curr: u8) -> Option { // Safety: both `notify` and `state` are `Unpin`. is_unpin::<&Notify>(); - is_unpin::(); + is_unpin::(); let me = self.get_unchecked_mut(); (&me.notify, &mut me.state, &me.waiter) @@ -438,11 +533,16 @@ impl Future for Notified<'_> { loop { match *state { - Init => { + Init(initial_notify_waiters_calls) => { + let curr = notify.state.load(SeqCst); + // Optimistically try acquiring a pending notification - let res = notify - .state - .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst); + let res = notify.state.compare_exchange( + set_state(curr, NOTIFIED), + set_state(curr, EMPTY), + SeqCst, + SeqCst, + ); if res.is_ok() { // Acquired the notification @@ -457,17 +557,27 @@ impl Future for Notified<'_> { // Reload the state with the lock held let mut curr = notify.state.load(SeqCst); + // if notify_waiters has been called after the future + // was created, then we are done + if get_num_notify_waiters_calls(curr) != initial_notify_waiters_calls { + *state = Done; + return Poll::Ready(()); + } + // Transition the state to WAITING. loop { - match curr { + match get_state(curr) { EMPTY => { // Transition to WAITING - let res = notify - .state - .compare_exchange(EMPTY, WAITING, SeqCst, SeqCst); + let res = notify.state.compare_exchange( + set_state(curr, EMPTY), + set_state(curr, WAITING), + SeqCst, + SeqCst, + ); if let Err(actual) = res { - assert_eq!(actual, NOTIFIED); + assert_eq!(get_state(actual), NOTIFIED); curr = actual; } else { break; @@ -476,9 +586,12 @@ impl Future for Notified<'_> { WAITING => break, NOTIFIED => { // Try consuming the notification - let res = notify - .state - .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst); + let res = notify.state.compare_exchange( + set_state(curr, NOTIFIED), + set_state(curr, EMPTY), + SeqCst, + SeqCst, + ); match res { Ok(_) => { @@ -487,7 +600,7 @@ impl Future for Notified<'_> { return Poll::Ready(()); } Err(actual) => { - assert_eq!(actual, EMPTY); + assert_eq!(get_state(actual), EMPTY); curr = actual; } } @@ -563,8 +676,8 @@ impl Drop for Notified<'_> { // dropped, which means we must ensure that the waiter entry is no // longer stored in the linked list. if let Waiting = *state { - let mut notify_state = WAITING; let mut waiters = notify.waiters.lock(); + let mut notify_state = notify.state.load(SeqCst); // `Notify.state` may be in any of the three states (Empty, Waiting, // Notified). It doesn't actually matter what the atomic is set to @@ -587,14 +700,14 @@ impl Drop for Notified<'_> { unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) }; if waiters.is_empty() { - notify_state = EMPTY; + notify_state = set_state(notify_state, EMPTY); // If the state *should* be `NOTIFIED`, the call to // `notify_locked` below will end up doing the // `store(NOTIFIED)`. If a concurrent receiver races and // observes the incorrect `EMPTY` state, it will then obtain the // lock and block until `notify.state` is in the correct final // state. - notify.state.store(EMPTY, SeqCst); + notify.state.store(notify_state, SeqCst); } // See if the node was notified but not received. In this case, if diff --git a/tokio/src/sync/tests/loom_notify.rs b/tokio/src/sync/tests/loom_notify.rs index 79a5bf89..4be949a3 100644 --- a/tokio/src/sync/tests/loom_notify.rs +++ b/tokio/src/sync/tests/loom_notify.rs @@ -21,6 +21,27 @@ fn notify_one() { }); } +#[test] +fn notify_waiters() { + loom::model(|| { + let notify = Arc::new(Notify::new()); + let tx = notify.clone(); + let notified1 = notify.notified(); + let notified2 = notify.notified(); + + let th = thread::spawn(move || { + tx.notify_waiters(); + }); + + th.join().unwrap(); + + block_on(async { + notified1.await; + notified2.await; + }); + }); +} + #[test] fn notify_multi() { loom::model(|| { diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index ec73832f..b377ca7f 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -241,25 +241,10 @@ impl Receiver { /// } /// ``` pub async fn changed(&mut self) -> Result<(), error::RecvError> { - use std::future::Future; - use std::pin::Pin; - use std::task::Poll; - // In order to avoid a race condition, we first request a notification, // **then** check the current value's version. If a new version exists, - // the notification request is dropped. Requesting the notification - // requires polling the future once. + // the notification request is dropped. let notified = self.shared.notify_rx.notified(); - pin!(notified); - - // Polling the future once is guaranteed to return `Pending` as `watch` - // only notifies using `notify_waiters`. - crate::future::poll_fn(|cx| { - let res = Pin::new(&mut notified).poll(cx); - assert!(!res.is_ready()); - Poll::Ready(()) - }) - .await; if let Some(ret) = maybe_changed(&self.shared, &mut self.version) { return ret; diff --git a/tokio/tests/sync_notify.rs b/tokio/tests/sync_notify.rs index 8c70fe39..8ffe020f 100644 --- a/tokio/tests/sync_notify.rs +++ b/tokio/tests/sync_notify.rs @@ -100,3 +100,37 @@ fn notified_multi_notify_drop_one() { assert!(notified2.is_woken()); assert_ready!(notified2.poll()); } + +#[test] +fn notify_in_drop_after_wake() { + use futures::task::ArcWake; + use std::future::Future; + use std::sync::Arc; + + let notify = Arc::new(Notify::new()); + + struct NotifyOnDrop(Arc); + + impl ArcWake for NotifyOnDrop { + fn wake_by_ref(_arc_self: &Arc) {} + } + + impl Drop for NotifyOnDrop { + fn drop(&mut self) { + self.0.notify_waiters(); + } + } + + let mut fut = Box::pin(async { + notify.notified().await; + }); + + { + let waker = futures::task::waker(Arc::new(NotifyOnDrop(notify.clone()))); + let mut cx = std::task::Context::from_waker(&waker); + assert!(fut.as_mut().poll(&mut cx).is_pending()); + } + + // Now, notifying **should not** deadlock + notify.notify_waiters(); +} -- cgit v1.2.3