summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorZahari Dichev <zaharidichev@gmail.com>2020-11-16 22:49:35 +0200
committerGitHub <noreply@github.com>2020-11-16 12:49:35 -0800
commitd0ebb4154748166a4ba07baa4b424a1c45efd219 (patch)
tree5ea4d611256290f62baea1a9ffa3333b254181df
parentf5cb4c20422a35b51bfba3391744f8bcb54f7581 (diff)
sync: add `Notify::notify_waiters` (#3098)
This PR makes `Notify::notify_waiters` public. The method already exists, but it changes the way `notify_waiters`, is used. Previously in order for the consumer to register interest, in a notification triggered by `notify_waiters`, the `Notified` future had to be polled. This introduced friction when using the api as the future had to be pinned before polled. This change introduces a counter that tracks how many times `notified_waiters` has been called. Upon creation of the future the number of times is loaded. When first polled the future compares this number with the count state of the `Notify` type. This avoids the need for registering the waiter upfront. Fixes: #3066
-rw-r--r--tokio/src/sync/mpsc/chan.rs19
-rw-r--r--tokio/src/sync/notify.rs207
-rw-r--r--tokio/src/sync/tests/loom_notify.rs21
-rw-r--r--tokio/src/sync/watch.rs17
-rw-r--r--tokio/tests/sync_notify.rs34
5 files changed, 218 insertions, 80 deletions
diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs
index c78fb501..a40f5c3d 100644
--- a/tokio/src/sync/mpsc/chan.rs
+++ b/tokio/src/sync/mpsc/chan.rs
@@ -148,25 +148,10 @@ impl<T, S: Semaphore> Tx<T, S> {
}
pub(crate) async fn closed(&self) {
- use std::future::Future;
- use std::pin::Pin;
- use std::task::Poll;
-
// In order to avoid a race condition, we first request a notification,
- // **then** check the current value's version. If a new version exists,
- // the notification request is dropped. Requesting the notification
- // requires polling the future once.
+ // **then** check whether the semaphore is closed. If the semaphore is
+ // closed the notification request is dropped.
let notified = self.inner.notify_rx_closed.notified();
- pin!(notified);
-
- // Polling the future once is guaranteed to return `Pending` as `watch`
- // only notifies using `notify_waiters`.
- crate::future::poll_fn(|cx| {
- let res = Pin::new(&mut notified).poll(cx);
- assert!(!res.is_ready());
- Poll::Ready(())
- })
- .await;
if self.inner.semaphore.is_closed() {
return;
diff --git a/tokio/src/sync/notify.rs b/tokio/src/sync/notify.rs
index 922f1095..f39f92f8 100644
--- a/tokio/src/sync/notify.rs
+++ b/tokio/src/sync/notify.rs
@@ -5,7 +5,7 @@
// triggers this warning but it is safe to ignore in this case.
#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
-use crate::loom::sync::atomic::AtomicU8;
+use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Mutex;
use crate::util::linked_list::{self, LinkedList};
@@ -109,7 +109,11 @@ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
/// [`Semaphore`]: crate::sync::Semaphore
#[derive(Debug)]
pub struct Notify {
- state: AtomicU8,
+ // This uses 2 bits to store one of `EMPTY`,
+ // `WAITING` or `NOTIFIED`. The rest of the bits
+ // are used to store the number of times `notify_waiters`
+ // was called.
+ state: AtomicUsize,
waiters: Mutex<WaitList>,
}
@@ -154,19 +158,39 @@ unsafe impl<'a> Sync for Notified<'a> {}
#[derive(Debug)]
enum State {
- Init,
+ Init(usize),
Waiting,
Done,
}
+const NOTIFY_WAITERS_SHIFT: usize = 2;
+const STATE_MASK: usize = (1 << NOTIFY_WAITERS_SHIFT) - 1;
+const NOTIFY_WAITERS_CALLS_MASK: usize = !STATE_MASK;
+
/// Initial "idle" state
-const EMPTY: u8 = 0;
+const EMPTY: usize = 0;
/// One or more threads are currently waiting to be notified.
-const WAITING: u8 = 1;
+const WAITING: usize = 1;
/// Pending notification
-const NOTIFIED: u8 = 2;
+const NOTIFIED: usize = 2;
+
+fn set_state(data: usize, state: usize) -> usize {
+ (data & NOTIFY_WAITERS_CALLS_MASK) | (state & STATE_MASK)
+}
+
+fn get_state(data: usize) -> usize {
+ data & STATE_MASK
+}
+
+fn get_num_notify_waiters_calls(data: usize) -> usize {
+ (data & NOTIFY_WAITERS_CALLS_MASK) >> NOTIFY_WAITERS_SHIFT
+}
+
+fn inc_num_notify_waiters_calls(data: usize) -> usize {
+ data + (1 << NOTIFY_WAITERS_SHIFT)
+}
impl Notify {
/// Create a new `Notify`, initialized without a permit.
@@ -180,7 +204,7 @@ impl Notify {
/// ```
pub fn new() -> Notify {
Notify {
- state: AtomicU8::new(0),
+ state: AtomicUsize::new(0),
waiters: Mutex::new(LinkedList::new()),
}
}
@@ -198,7 +222,7 @@ impl Notify {
#[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))]
pub const fn const_new() -> Notify {
Notify {
- state: AtomicU8::new(0),
+ state: AtomicUsize::new(0),
waiters: Mutex::const_new(LinkedList::new()),
}
}
@@ -239,9 +263,12 @@ impl Notify {
/// }
/// ```
pub fn notified(&self) -> Notified<'_> {
+ // we load the number of times notify_waiters
+ // was called and store that in our initial state
+ let state = self.state.load(SeqCst);
Notified {
notify: self,
- state: State::Init,
+ state: State::Init(state >> NOTIFY_WAITERS_SHIFT),
waiter: UnsafeCell::new(Waiter {
pointers: linked_list::Pointers::new(),
waker: None,
@@ -290,11 +317,12 @@ impl Notify {
let mut curr = self.state.load(SeqCst);
// If the state is `EMPTY`, transition to `NOTIFIED` and return.
- while let EMPTY | NOTIFIED = curr {
+ while let EMPTY | NOTIFIED = get_state(curr) {
// The compare-exchange from `NOTIFIED` -> `NOTIFIED` is intended. A
// happens-before synchronization must happen between this atomic
// operation and a task calling `notified().await`.
- let res = self.state.compare_exchange(curr, NOTIFIED, SeqCst, SeqCst);
+ let new = set_state(curr, NOTIFIED);
+ let res = self.state.compare_exchange(curr, new, SeqCst, SeqCst);
match res {
// No waiters, no further work to do
@@ -319,7 +347,43 @@ impl Notify {
}
/// Notifies all waiting tasks
- pub(crate) fn notify_waiters(&self) {
+ ///
+ /// If a task is currently waiting, that task is notified. Unlike with
+ /// `notify()`, no permit is stored to be used by the next call to
+ /// [`notified().await`]. The purpose of this method is to notify all
+ /// already registered waiters. Registering for notification is done by
+ /// acquiring an instance of the `Notified` future via calling `notified()`.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::Notify;
+ /// use std::sync::Arc;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let notify = Arc::new(Notify::new());
+ /// let notify2 = notify.clone();
+ ///
+ /// let notified1 = notify.notified();
+ /// let notified2 = notify.notified();
+ ///
+ /// let handle = tokio::spawn(async move {
+ /// println!("sending notifications");
+ /// notify2.notify_waiters();
+ /// });
+ ///
+ /// notified1.await;
+ /// notified2.await;
+ /// println!("received notifications");
+ /// }
+ /// ```
+ pub fn notify_waiters(&self) {
+ const NUM_WAKERS: usize = 32;
+
+ let mut wakers: [Option<Waker>; NUM_WAKERS] = Default::default();
+ let mut curr_waker = 0;
+
// There are waiters, the lock must be acquired to notify.
let mut waiters = self.waiters.lock();
@@ -327,34 +391,64 @@ impl Notify {
// transition out of WAITING while the lock is held.
let curr = self.state.load(SeqCst);
- if let EMPTY | NOTIFIED = curr {
+ if let EMPTY | NOTIFIED = get_state(curr) {
// There are no waiting tasks. In this case, no synchronization is
// established between `notify` and `notified().await`.
+ // All we need to do is increment the number of times this
+ // method was called.
+ self.state.store(inc_num_notify_waiters_calls(curr), SeqCst);
return;
}
// At this point, it is guaranteed that the state will not
// concurrently change, as holding the lock is required to
// transition **out** of `WAITING`.
- //
- // Get pending waiters
- while let Some(mut waiter) = waiters.pop_back() {
- // Safety: `waiters` lock is still held.
- let waiter = unsafe { waiter.as_mut() };
+ 'outer: loop {
+ while curr_waker < NUM_WAKERS {
+ match waiters.pop_back() {
+ Some(mut waiter) => {
+ // Safety: `waiters` lock is still held.
+ let waiter = unsafe { waiter.as_mut() };
- assert!(waiter.notified.is_none());
+ assert!(waiter.notified.is_none());
- waiter.notified = Some(NotificationType::AllWaiters);
+ waiter.notified = Some(NotificationType::AllWaiters);
- if let Some(waker) = waiter.waker.take() {
- waker.wake();
+ if let Some(waker) = waiter.waker.take() {
+ wakers[curr_waker] = Some(waker);
+ curr_waker += 1;
+ }
+ }
+ None => {
+ break 'outer;
+ }
+ }
}
+
+ drop(waiters);
+
+ for waker in wakers.iter_mut().take(curr_waker) {
+ waker.take().unwrap().wake();
+ }
+
+ curr_waker = 0;
+
+ // Acquire the lock again.
+ waiters = self.waiters.lock();
}
- // All waiters have been notified, the state must be transitioned to
+ // All waiters will be notified, the state must be transitioned to
// `EMPTY`. As transitioning **from** `WAITING` requires the lock to be
// held, a `store` is sufficient.
- self.state.store(EMPTY, SeqCst);
+ let new = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
+ self.state.store(new, SeqCst);
+
+ // Release the lock before notifying
+ drop(waiters);
+
+ for waker in wakers.iter_mut().take(curr_waker) {
+ waker.take().unwrap().wake();
+ }
}
}
@@ -364,17 +458,18 @@ impl Default for Notify {
}
}
-fn notify_locked(waiters: &mut WaitList, state: &AtomicU8, curr: u8) -> Option<Waker> {
+fn notify_locked(waiters: &mut WaitList, state: &AtomicUsize, curr: usize) -> Option<Waker> {
loop {
- match curr {
+ match get_state(curr) {
EMPTY | NOTIFIED => {
- let res = state.compare_exchange(curr, NOTIFIED, SeqCst, SeqCst);
+ let res = state.compare_exchange(curr, set_state(curr, NOTIFIED), SeqCst, SeqCst);
match res {
Ok(_) => return None,
Err(actual) => {
- assert!(actual == EMPTY || actual == NOTIFIED);
- state.store(NOTIFIED, SeqCst);
+ let actual_state = get_state(actual);
+ assert!(actual_state == EMPTY || actual_state == NOTIFIED);
+ state.store(set_state(actual, NOTIFIED), SeqCst);
return None;
}
}
@@ -400,7 +495,7 @@ fn notify_locked(waiters: &mut WaitList, state: &AtomicU8, curr: u8) -> Option<W
// must be transitioned to `EMPTY`. As transitioning
// **from** `WAITING` requires the lock to be held, a
// `store` is sufficient.
- state.store(EMPTY, SeqCst);
+ state.store(set_state(curr, EMPTY), SeqCst);
}
return waker;
@@ -420,7 +515,7 @@ impl Notified<'_> {
// Safety: both `notify` and `state` are `Unpin`.
is_unpin::<&Notify>();
- is_unpin::<AtomicU8>();
+ is_unpin::<AtomicUsize>();
let me = self.get_unchecked_mut();
(&me.notify, &mut me.state, &me.waiter)
@@ -438,11 +533,16 @@ impl Future for Notified<'_> {
loop {
match *state {
- Init => {
+ Init(initial_notify_waiters_calls) => {
+ let curr = notify.state.load(SeqCst);
+
// Optimistically try acquiring a pending notification
- let res = notify
- .state
- .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst);
+ let res = notify.state.compare_exchange(
+ set_state(curr, NOTIFIED),
+ set_state(curr, EMPTY),
+ SeqCst,
+ SeqCst,
+ );
if res.is_ok() {
// Acquired the notification
@@ -457,17 +557,27 @@ impl Future for Notified<'_> {
// Reload the state with the lock held
let mut curr = notify.state.load(SeqCst);
+ // if notify_waiters has been called after the future
+ // was created, then we are done
+ if get_num_notify_waiters_calls(curr) != initial_notify_waiters_calls {
+ *state = Done;
+ return Poll::Ready(());
+ }
+
// Transition the state to WAITING.
loop {
- match curr {
+ match get_state(curr) {
EMPTY => {
// Transition to WAITING
- let res = notify
- .state
- .compare_exchange(EMPTY, WAITING, SeqCst, SeqCst);
+ let res = notify.state.compare_exchange(
+ set_state(curr, EMPTY),
+ set_state(curr, WAITING),
+ SeqCst,
+ SeqCst,
+ );
if let Err(actual) = res {
- assert_eq!(actual, NOTIFIED);
+ assert_eq!(get_state(actual), NOTIFIED);
curr = actual;
} else {
break;
@@ -476,9 +586,12 @@ impl Future for Notified<'_> {
WAITING => break,
NOTIFIED => {
// Try consuming the notification
- let res = notify
- .state
- .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst);
+ let res = notify.state.compare_exchange(
+ set_state(curr, NOTIFIED),
+ set_state(curr, EMPTY),
+ SeqCst,
+ SeqCst,
+ );
match res {
Ok(_) => {
@@ -487,7 +600,7 @@ impl Future for Notified<'_> {
return Poll::Ready(());
}
Err(actual) => {
- assert_eq!(actual, EMPTY);
+ assert_eq!(get_state(actual), EMPTY);
curr = actual;
}
}
@@ -563,8 +676,8 @@ impl Drop for Notified<'_> {
// dropped, which means we must ensure that the waiter entry is no
// longer stored in the linked list.
if let Waiting = *state {
- let mut notify_state = WAITING;
let mut waiters = notify.waiters.lock();
+ let mut notify_state = notify.state.load(SeqCst);
// `Notify.state` may be in any of the three states (Empty, Waiting,
// Notified). It doesn't actually matter what the atomic is set to
@@ -587,14 +700,14 @@ impl Drop for Notified<'_> {
unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) };
if waiters.is_empty() {
- notify_state = EMPTY;
+ notify_state = set_state(notify_state, EMPTY);
// If the state *should* be `NOTIFIED`, the call to
// `notify_locked` below will end up doing the
// `store(NOTIFIED)`. If a concurrent receiver races and
// observes the incorrect `EMPTY` state, it will then obtain the
// lock and block until `notify.state` is in the correct final
// state.
- notify.state.store(EMPTY, SeqCst);
+ notify.state.store(notify_state, SeqCst);
}
// See if the node was notified but not received. In this case, if
diff --git a/tokio/src/sync/tests/loom_notify.rs b/tokio/src/sync/tests/loom_notify.rs
index 79a5bf89..4be949a3 100644
--- a/tokio/src/sync/tests/loom_notify.rs
+++ b/tokio/src/sync/tests/loom_notify.rs
@@ -22,6 +22,27 @@ fn notify_one() {
}
#[test]
+fn notify_waiters() {
+ loom::model(|| {
+ let notify = Arc::new(Notify::new());
+ let tx = notify.clone();
+ let notified1 = notify.notified();
+ let notified2 = notify.notified();
+
+ let th = thread::spawn(move || {
+ tx.notify_waiters();
+ });
+
+ th.join().unwrap();
+
+ block_on(async {
+ notified1.await;
+ notified2.await;
+ });
+ });
+}
+
+#[test]
fn notify_multi() {
loom::model(|| {
let notify = Arc::new(Notify::new());
diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs
index ec73832f..b377ca7f 100644
--- a/tokio/src/sync/watch.rs
+++ b/tokio/src/sync/watch.rs
@@ -241,25 +241,10 @@ impl<T> Receiver<T> {
/// }
/// ```
pub async fn changed(&mut self) -> Result<(), error::RecvError> {
- use std::future::Future;
- use std::pin::Pin;
- use std::task::Poll;
-
// In order to avoid a race condition, we first request a notification,
// **then** check the current value's version. If a new version exists,
- // the notification request is dropped. Requesting the notification
- // requires polling the future once.
+ // the notification request is dropped.
let notified = self.shared.notify_rx.notified();
- pin!(notified);
-
- // Polling the future once is guaranteed to return `Pending` as `watch`
- // only notifies using `notify_waiters`.
- crate::future::poll_fn(|cx| {
- let res = Pin::new(&mut notified).poll(cx);
- assert!(!res.is_ready());
- Poll::Ready(())
- })
- .await;
if let Some(ret) = maybe_changed(&self.shared, &mut self.version) {
return ret;
diff --git a/tokio/tests/sync_notify.rs b/tokio/tests/sync_notify.rs
index 8c70fe39..8ffe020f 100644
--- a/tokio/tests/sync_notify.rs
+++ b/tokio/tests/sync_notify.rs
@@ -100,3 +100,37 @@ fn notified_multi_notify_drop_one() {
assert!(notified2.is_woken());
assert_ready!(notified2.poll());
}
+
+#[test]
+fn notify_in_drop_after_wake() {
+ use futures::task::ArcWake;
+ use std::future::Future;
+ use std::sync::Arc;
+
+ let notify = Arc::new(Notify::new());
+
+ struct NotifyOnDrop(Arc<Notify>);
+
+ impl ArcWake for NotifyOnDrop {
+ fn wake_by_ref(_arc_self: &Arc<Self>) {}
+ }
+
+ impl Drop for NotifyOnDrop {
+ fn drop(&mut self) {
+ self.0.notify_waiters();
+ }
+ }
+
+ let mut fut = Box::pin(async {
+ notify.notified().await;
+ });
+
+ {
+ let waker = futures::task::waker(Arc::new(NotifyOnDrop(notify.clone())));
+ let mut cx = std::task::Context::from_waker(&waker);
+ assert!(fut.as_mut().poll(&mut cx).is_pending());
+ }
+
+ // Now, notifying **should not** deadlock
+ notify.notify_waiters();
+}