diff options
Diffstat (limited to 'tokio/src/sync/notify.rs')
-rw-r--r-- | tokio/src/sync/notify.rs | 72 |
1 files changed, 52 insertions, 20 deletions
diff --git a/tokio/src/sync/notify.rs b/tokio/src/sync/notify.rs index f3c1bda1..ef34ad2d 100644 --- a/tokio/src/sync/notify.rs +++ b/tokio/src/sync/notify.rs @@ -2,8 +2,11 @@ use crate::loom::sync::atomic::AtomicU8; use crate::loom::sync::Mutex; use crate::util::linked_list::{self, LinkedList}; +use std::cell::UnsafeCell; use std::future::Future; +use std::marker::PhantomPinned; use std::pin::Pin; +use std::ptr::NonNull; use std::sync::atomic::Ordering::SeqCst; use std::task::{Context, Poll, Waker}; @@ -103,11 +106,17 @@ pub struct Notify { #[derive(Debug)] struct Waiter { + /// Intrusive linked-list pointers + pointers: linked_list::Pointers<Waiter>, + /// Waiting task's waker waker: Option<Waker>, /// `true` if the notification has been assigned to this waiter. notified: bool, + + /// Should not be `Unpin`. + _p: PhantomPinned, } /// Future returned from `notified()` @@ -120,9 +129,12 @@ struct Notified<'a> { state: State, /// Entry in the waiter `LinkedList`. - waiter: linked_list::Entry<Waiter>, + waiter: UnsafeCell<Waiter>, } +unsafe impl<'a> Send for Notified<'a> {} +unsafe impl<'a> Sync for Notified<'a> {} + #[derive(Debug)] enum State { Init, @@ -189,9 +201,11 @@ impl Notify { Notified { notify: self, state: State::Init, - waiter: linked_list::Entry::new(Waiter { + waiter: UnsafeCell::new(Waiter { + pointers: linked_list::Pointers::new(), waker: None, notified: false, + _p: PhantomPinned, }), } .await @@ -292,7 +306,10 @@ fn notify_locked(waiters: &mut LinkedList<Waiter>, state: &AtomicU8, curr: u8) - // transition **out** of `WAITING`. // // Get a pending waiter - let mut waiter = waiters.pop_back().unwrap(); + let waiter = waiters.pop_back().unwrap(); + + // Safety: `waiters` lock is still held. + let waiter = unsafe { &mut *waiter }; assert!(!waiter.notified); @@ -319,9 +336,7 @@ fn notify_locked(waiters: &mut LinkedList<Waiter>, state: &AtomicU8, curr: u8) - impl Notified<'_> { /// A custom `project` implementation is used in place of `pin-project-lite` /// as a custom drop implementation is needed. - fn project( - self: Pin<&mut Self>, - ) -> (&Notify, &mut State, Pin<&mut linked_list::Entry<Waiter>>) { + fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &UnsafeCell<Waiter>) { unsafe { // Safety: both `notify` and `state` are `Unpin`. @@ -329,11 +344,7 @@ impl Notified<'_> { is_unpin::<AtomicU8>(); let me = self.get_unchecked_mut(); - ( - &me.notify, - &mut me.state, - Pin::new_unchecked(&mut me.waiter), - ) + (&me.notify, &mut me.state, &me.waiter) } } } @@ -344,7 +355,7 @@ impl Future for Notified<'_> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { use State::*; - let (notify, state, mut waiter) = self.project(); + let (notify, state, waiter) = self.project(); loop { match *state { @@ -408,12 +419,12 @@ impl Future for Notified<'_> { // Safety: called while locked. unsafe { - (*waiter.as_mut().get()).waker = Some(cx.waker().clone()); - - // Insert the waiter into the linked list - waiters.push_front(waiter.as_mut()); + (*waiter.get()).waker = Some(cx.waker().clone()); } + // Insert the waiter into the linked list + waiters.push_front(waiter.get()); + *state = Waiting; } Waiting => { @@ -425,7 +436,7 @@ impl Future for Notified<'_> { let waiters = notify.waiters.lock().unwrap(); // Safety: called while locked - let w = unsafe { &mut *waiter.as_mut().get() }; + let w = unsafe { &mut *waiter.get() }; if w.notified { // Our waker has been notified. Reset the fields and @@ -463,7 +474,7 @@ impl Drop for Notified<'_> { use State::*; // Safety: The type only transitions to a "Waiting" state when pinned. - let (notify, state, mut waiter) = unsafe { Pin::new_unchecked(self).project() }; + let (notify, state, waiter) = unsafe { Pin::new_unchecked(self).project() }; // This is where we ensure safety. The `Notified` value is being // dropped, which means we must ensure that the waiter entry is no @@ -490,7 +501,7 @@ impl Drop for Notified<'_> { // // safety: the waiter is only added to `waiters` by virtue of it // being the only `LinkedList` available to the type. - unsafe { waiters.remove(waiter.as_mut()) }; + unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) }; if waiters.is_empty() { notify_state = EMPTY; @@ -508,7 +519,7 @@ impl Drop for Notified<'_> { // // Safety: with the entry removed from the linked list, there can be // no concurrent access to the entry - let notified = unsafe { (*waiter.as_mut().get()).notified }; + let notified = unsafe { (*waiter.get()).notified }; if notified { if let Some(waker) = notify_locked(&mut waiters, ¬ify.state, notify_state) { @@ -520,4 +531,25 @@ impl Drop for Notified<'_> { } } +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + type Handle = *mut Waiter; + type Target = Waiter; + + fn to_raw(handle: *mut Waiter) -> NonNull<Waiter> { + debug_assert!(!handle.is_null()); + unsafe { NonNull::new_unchecked(handle) } + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> *mut Waiter { + ptr.as_ptr() + } + + unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + NonNull::from(&mut target.as_mut().pointers) + } +} + fn is_unpin<T: Unpin>() {} |