-rw-r--r--  tokio/src/sync/batch_semaphore.rs             | 146
-rw-r--r--  tokio/src/sync/tests/loom_semaphore_batch.rs  |  44
-rw-r--r--  tokio/src/util/linked_list.rs                 | 178
3 files changed, 135 insertions(+), 233 deletions(-)
diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs
index 3656c109..5d15311d 100644
--- a/tokio/src/sync/batch_semaphore.rs
+++ b/tokio/src/sync/batch_semaphore.rs
@@ -129,13 +129,8 @@ impl Semaphore {
return;
}
- // Assign permits to the wait queue, returning a list containing all the
- // waiters at the back of the queue that received enough permits to wake
- // up.
- let notified = self.add_permits_locked(added, self.waiters.lock().unwrap());
-
- // Once we release the lock, notify all woken waiters.
- notify_all(notified);
+ // Assign permits to the wait queue
+ self.add_permits_locked(added, self.waiters.lock().unwrap());
}
/// Closes the semaphore. This prevents the semaphore from issuing new
@@ -144,20 +139,22 @@ impl Semaphore {
// semaphore implementation.
#[allow(dead_code)]
pub(crate) fn close(&self) {
- let notified = {
- let mut waiters = self.waiters.lock().unwrap();
- // If the semaphore's permits counter has enough permits for an
- // unqueued waiter to acquire all the permits it needs immediately,
- // it won't touch the wait list. Therefore, we have to set a bit on
- // the permit counter as well. However, we must do this while
- // holding the lock --- otherwise, if we set the bit and then wait
- // to acquire the lock we'll enter an inconsistent state where the
- // permit counter is closed, but the wait list is not.
- self.permits.fetch_or(Self::CLOSED, Release);
- waiters.closed = true;
- waiters.queue.take_all()
- };
- notify_all(notified)
+ let mut waiters = self.waiters.lock().unwrap();
+ // If the semaphore's permits counter has enough permits for an
+ // unqueued waiter to acquire all the permits it needs immediately,
+ // it won't touch the wait list. Therefore, we have to set a bit on
+ // the permit counter as well. However, we must do this while
+ // holding the lock --- otherwise, if we set the bit and then wait
+ // to acquire the lock we'll enter an inconsistent state where the
+ // permit counter is closed, but the wait list is not.
+ self.permits.fetch_or(Self::CLOSED, Release);
+ waiters.closed = true;
+ while let Some(mut waiter) = waiters.queue.pop_back() {
+ let waker = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) };
+ if let Some(waker) = waker {
+ waker.wake();
+ }
+ }
}
pub(crate) fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> {
@@ -189,58 +186,60 @@ impl Semaphore {
/// Release `rem` permits to the semaphore's wait list, starting from the
/// end of the queue.
- ///
- /// This returns a new `LinkedList` containing all the waiters that received
- /// enough permits to be notified. Once the lock on the wait list is
- /// released, this list should be drained and the waiters in it notified.
///
/// If `rem` exceeds the number of permits needed by the wait list, the
/// remainder are assigned back to the semaphore.
- fn add_permits_locked(
- &self,
- mut rem: usize,
- mut waiters: MutexGuard<'_, Waitlist>,
- ) -> LinkedList<Waiter> {
- // Starting from the back of the wait queue, assign each waiter as many
- // permits as it needs until we run out of permits to assign.
- let mut last = None;
- for waiter in waiters.queue.iter().rev() {
- // Was the waiter assigned enough permits to wake it?
- if !waiter.assign_permits(&mut rem) {
- break;
+ fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) {
+ let mut wakers: [Option<Waker>; 8] = Default::default();
+ let mut lock = Some(waiters);
+ let mut is_empty = false;
+ while rem > 0 {
+ let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap());
+ 'inner: for slot in &mut wakers[..] {
+ // Was the waiter assigned enough permits to wake it?
+ match waiters.queue.last() {
+ Some(waiter) => {
+ if !waiter.assign_permits(&mut rem) {
+ break 'inner;
+ }
+ }
+ None => {
+ is_empty = true;
+ // If we assigned permits to all the waiters in the queue, and there are
+ // still permits left over, assign them back to the semaphore.
+ break 'inner;
+ }
+ };
+ let mut waiter = waiters.queue.pop_back().unwrap();
+ *slot = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) };
}
- last = Some(NonNull::from(waiter));
- }
- // If we assigned permits to all the waiters in the queue, and there are
- // still permits left over, assign them back to the semaphore.
- if rem > 0 {
- let permits = rem << Self::PERMIT_SHIFT;
- assert!(
- permits < Self::MAX_PERMITS,
- "cannot add more than MAX_PERMITS permits ({})",
- Self::MAX_PERMITS
- );
- let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release);
- assert!(
- prev + permits <= Self::MAX_PERMITS,
- "number of added permits ({}) would overflow MAX_PERMITS ({})",
- rem,
- Self::MAX_PERMITS
- );
- }
+ if rem > 0 && is_empty {
+ let permits = rem << Self::PERMIT_SHIFT;
+ assert!(
+ permits < Self::MAX_PERMITS,
+ "cannot add more than MAX_PERMITS permits ({})",
+ Self::MAX_PERMITS
+ );
+ let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release);
+ assert!(
+ prev + permits <= Self::MAX_PERMITS,
+ "number of added permits ({}) would overflow MAX_PERMITS ({})",
+ rem,
+ Self::MAX_PERMITS
+ );
+ rem = 0;
+ }
- // Split off the queue at the last waiter that was satisfied, creating a
- // new list. Once we release the lock, we'll drain this list and notify
- // the waiters in it.
- if let Some(waiter) = last {
- // Safety: it's only safe to call `split_back` with a pointer to a
- // node in the same list as the one we call `split_back` on. Since
- // we got the waiter pointer from the list's iterator, this is fine.
- unsafe { waiters.queue.split_back(waiter) }
- } else {
- LinkedList::new()
+ drop(waiters); // release the lock
+
+ wakers
+ .iter_mut()
+ .filter_map(Option::take)
+ .for_each(Waker::wake);
}
+
+ assert_eq!(rem, 0);
}
fn poll_acquire(
@@ -354,18 +353,6 @@ impl fmt::Debug for Semaphore {
}
}
-/// Pop all waiters from `list`, starting at the end of the queue, and notify
-/// them.
-fn notify_all(mut list: LinkedList<Waiter>) {
- while let Some(waiter) = list.pop_back() {
- let waker = unsafe { waiter.as_ref().waker.with_mut(|waker| (*waker).take()) };
-
- waker
- .expect("if a node is in the wait list, it must have a waker")
- .wake();
- }
-}
-
impl Waiter {
fn new(num_permits: u16) -> Self {
Waiter {
@@ -471,8 +458,7 @@ impl Drop for Acquire<'_> {
let acquired_permits = self.num_permits as usize - self.node.state.load(Acquire);
if acquired_permits > 0 {
- let notified = self.semaphore.add_permits_locked(acquired_permits, waiters);
- notify_all(notified);
+ self.semaphore.add_permits_locked(acquired_permits, waiters);
}
}
}
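
The refactored `add_permits_locked` above replaces the old split-off-and-notify approach with a loop that fills a small fixed-size array of wakers while holding the lock, releases the lock, and only then wakes the collected tasks. Below is a minimal, self-contained sketch of that pattern under simplified assumptions; the names (`wake_in_batches`, `BATCH`, the `VecDeque` of wakers) are illustrative stand-ins, not tokio internals.

// Sketch of the wake-in-batches pattern: bound how long the lock is held by
// draining at most BATCH wakers per lock acquisition, and wake outside the lock.
use std::collections::VecDeque;
use std::sync::Mutex;
use std::task::Waker;

const BATCH: usize = 8;

fn wake_in_batches(queue: &Mutex<VecDeque<Waker>>) {
    loop {
        let mut wakers: [Option<Waker>; BATCH] = Default::default();
        {
            // Hold the lock only long enough to fill one batch.
            let mut q = queue.lock().unwrap();
            for slot in wakers.iter_mut() {
                match q.pop_back() {
                    Some(waker) => *slot = Some(waker),
                    None => break,
                }
            }
        } // lock released here, before any waker runs

        if wakers[0].is_none() {
            return; // queue was empty
        }

        // Waking outside the lock means a woken task that immediately tries to
        // re-acquire will not contend on (or deadlock against) this mutex.
        for waker in wakers.iter_mut().filter_map(Option::take) {
            waker.wake();
        }
    }
}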
diff --git a/tokio/src/sync/tests/loom_semaphore_batch.rs b/tokio/src/sync/tests/loom_semaphore_batch.rs
index 4c1936c5..76a1bc00 100644
--- a/tokio/src/sync/tests/loom_semaphore_batch.rs
+++ b/tokio/src/sync/tests/loom_semaphore_batch.rs
@@ -114,6 +114,50 @@ fn concurrent_close() {
}
#[test]
+fn concurrent_cancel() {
+ async fn poll_and_cancel(semaphore: Arc<Semaphore>) {
+ let mut acquire1 = Some(semaphore.acquire(1));
+ let mut acquire2 = Some(semaphore.acquire(1));
+ poll_fn(|cx| {
+ // poll the acquire future once, and then immediately throw
+ // it away. this simulates a situation where a future is
+ // polled and then cancelled, such as by a timeout.
+ if let Some(acquire) = acquire1.take() {
+ pin!(acquire);
+ let _ = acquire.poll(cx);
+ }
+ if let Some(acquire) = acquire2.take() {
+ pin!(acquire);
+ let _ = acquire.poll(cx);
+ }
+ Poll::Ready(())
+ })
+ .await
+ }
+
+ loom::model(|| {
+ let semaphore = Arc::new(Semaphore::new(0));
+ let t1 = {
+ let semaphore = semaphore.clone();
+ thread::spawn(move || block_on(poll_and_cancel(semaphore)))
+ };
+ let t2 = {
+ let semaphore = semaphore.clone();
+ thread::spawn(move || block_on(poll_and_cancel(semaphore)))
+ };
+ let t3 = {
+ let semaphore = semaphore.clone();
+ thread::spawn(move || block_on(poll_and_cancel(semaphore)))
+ };
+
+ t1.join().unwrap();
+ semaphore.release(10);
+ t2.join().unwrap();
+ t3.join().unwrap();
+ });
+}
+
+#[test]
fn batch() {
let mut b = loom::model::Builder::new();
b.preemption_bound = Some(1);
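
The `concurrent_cancel` loom model above polls each `acquire` future once and then drops it, simulating cancellation "such as by a timeout". For reference, a hedged sketch of what that situation looks like in application code, assuming the current public `tokio::sync::Semaphore` API (where `acquire()` returns a `Result`) with the `sync` and `time` features enabled; this is illustrative usage, not part of the change:

use std::time::Duration;
use tokio::sync::Semaphore;
use tokio::time::timeout;

async fn acquire_with_deadline(sem: &Semaphore) {
    match timeout(Duration::from_millis(10), sem.acquire()).await {
        // Acquired in time; the permit is released when `_permit` is dropped.
        Ok(Ok(_permit)) => {}
        // The semaphore was closed while we were waiting.
        Ok(Err(_closed)) => {}
        // Timed out: the acquire future is dropped while still queued, and any
        // partially acquired permits are handed back to the semaphore — the
        // path exercised by `Drop for Acquire` in the first file above.
        Err(_elapsed) => {}
    }
}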
diff --git a/tokio/src/util/linked_list.rs b/tokio/src/util/linked_list.rs
index 1a488032..aa3ce771 100644
--- a/tokio/src/util/linked_list.rs
+++ b/tokio/src/util/linked_list.rs
@@ -164,46 +164,6 @@ impl<T: Link> LinkedList<T> {
}
}
-cfg_sync! {
- impl<T: Link> LinkedList<T> {
- /// Splits this list off at `node`, returning a new list with `node` at its
- /// front.
- ///
- /// If `node` is at the the front of this list, then this list will be empty after
- /// splitting. If `node` is the last node in this list, then the returned
- /// list will contain only `node`.
- ///
- /// # Safety
- ///
- /// The caller **must** ensure that `node` is currently contained by
- /// `self` or not contained by any other list.
- pub(crate) unsafe fn split_back(&mut self, node: NonNull<T::Target>) -> Self {
- let new_tail = T::pointers(node).as_mut().prev.take().map(|prev| {
- T::pointers(prev).as_mut().next = None;
- prev
- });
- if new_tail.is_none() {
- self.head = None;
- }
- let tail = std::mem::replace(&mut self.tail, new_tail);
- Self {
- head: Some(node),
- tail,
- }
- }
-
- /// Takes all entries from this list, returning a new list.
- ///
- /// This list will be left empty.
- pub(crate) fn take_all(&mut self) -> Self {
- Self {
- head: self.head.take(),
- tail: self.tail.take(),
- }
- }
- }
-}
-
impl<T: Link> fmt::Debug for LinkedList<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("LinkedList")
@@ -213,49 +173,41 @@ impl<T: Link> fmt::Debug for LinkedList<T> {
}
}
-// ===== impl Iter =====
-
-#[cfg(any(feature = "sync", feature = "rt-threaded"))]
-pub(crate) struct Iter<'a, T: Link> {
- curr: Option<NonNull<T::Target>>,
- #[cfg(feature = "sync")]
- curr_back: Option<NonNull<T::Target>>,
- _p: core::marker::PhantomData<&'a T>,
-}
-
-#[cfg(any(feature = "sync", feature = "rt-threaded"))]
-impl<T: Link> LinkedList<T> {
- pub(crate) fn iter(&self) -> Iter<'_, T> {
- Iter {
- curr: self.head,
- #[cfg(feature = "sync")]
- curr_back: self.tail,
- _p: core::marker::PhantomData,
+cfg_sync! {
+ impl<T: Link> LinkedList<T> {
+ pub(crate) fn last(&self) -> Option<&T::Target> {
+ let tail = self.tail.as_ref()?;
+ unsafe {
+ Some(&*tail.as_ptr())
+ }
}
}
}
-#[cfg(any(feature = "sync", feature = "rt-threaded"))]
-impl<'a, T: Link> Iterator for Iter<'a, T> {
- type Item = &'a T::Target;
+// ===== impl Iter =====
- fn next(&mut self) -> Option<&'a T::Target> {
- let curr = self.curr?;
- // safety: the pointer references data contained by the list
- self.curr = unsafe { T::pointers(curr).as_ref() }.next;
+cfg_rt_threaded! {
+ pub(crate) struct Iter<'a, T: Link> {
+ curr: Option<NonNull<T::Target>>,
+ _p: core::marker::PhantomData<&'a T>,
+ }
- // safety: the value is still owned by the linked list.
- Some(unsafe { &*curr.as_ptr() })
+ impl<T: Link> LinkedList<T> {
+ pub(crate) fn iter(&self) -> Iter<'_, T> {
+ Iter {
+ curr: self.head,
+ _p: core::marker::PhantomData,
+ }
+ }
}
-}
-cfg_sync! {
- impl<'a, T: Link> DoubleEndedIterator for Iter<'a, T> {
- fn next_back(&mut self) -> Option<&'a T::Target> {
- let curr = self.curr_back?;
+ impl<'a, T: Link> Iterator for Iter<'a, T> {
+ type Item = &'a T::Target;
+ fn next(&mut self) -> Option<&'a T::Target> {
+ let curr = self.curr?;
// safety: the pointer references data contained by the list
- self.curr_back = unsafe { T::pointers(curr).as_ref() }.prev;
+ self.curr = unsafe { T::pointers(curr).as_ref() }.next;
// safety: the value is still owned by the linked list.
Some(unsafe { &*curr.as_ptr() })
@@ -564,86 +516,6 @@ mod tests {
assert!(i.next().is_none());
}
- #[test]
- fn split_back() {
- let a = entry(1);
- let b = entry(2);
- let c = entry(3);
- let d = entry(4);
-
- {
- let mut list1 = LinkedList::<&Entry>::new();
-
- push_all(
- &mut list1,
- &[a.as_ref(), b.as_ref(), c.as_ref(), d.as_ref()],
- );
- let mut list2 = unsafe { list1.split_back(ptr(&a)) };
-
- assert_eq!([2, 3, 4].to_vec(), collect_list(&mut list1));
- assert_eq!([1].to_vec(), collect_list(&mut list2));
- }
-
- {
- let mut list1 = LinkedList::<&Entry>::new();
-
- push_all(
- &mut list1,
- &[a.as_ref(), b.as_ref(), c.as_ref(), d.as_ref()],
- );
- let mut list2 = unsafe { list1.split_back(ptr(&b)) };
-
- assert_eq!([3, 4].to_vec(), collect_list(&mut list1));
- assert_eq!([1, 2].to_vec(), collect_list(&mut list2));
- }
-
- {
- let mut list1 = LinkedList::<&Entry>::new();
-
- push_all(
- &mut list1,
- &[a.as_ref(), b.as_ref(), c.as_ref(), d.as_ref()],
- );
- let mut list2 = unsafe { list1.split_back(ptr(&c)) };
-
- assert_eq!([4].to_vec(), collect_list(&mut list1));
- assert_eq!([1, 2, 3].to_vec(), collect_list(&mut list2));
- }
-
- {
- let mut list1 = LinkedList::<&Entry>::new();
-
- push_all(
- &mut list1,
- &[a.as_ref(), b.as_ref(), c.as_ref(), d.as_ref()],
- );
- let mut list2 = unsafe { list1.split_back(ptr(&d)) };
-
- assert_eq!(Vec::<i32>::new(), collect_list(&mut list1));
- assert_eq!([1, 2, 3, 4].to_vec(), collect_list(&mut list2));
- }
- }
-
- #[test]
- fn take_all() {
- let mut list1 = LinkedList::<&Entry>::new();
- let a = entry(1);
- let b = entry(2);
-
- list1.push_front(a.as_ref());
- list1.push_front(b.as_ref());
-
- assert!(!list1.is_empty());
-
- let mut list2 = list1.take_all();
-
- assert!(list1.is_empty());
- assert!(!list2.is_empty());
-
- assert_eq!(Vec::<i32>::new(), collect_list(&mut list1));
- assert_eq!([1, 2].to_vec(), collect_list(&mut list2));
- }
-
proptest::proptest! {
#[test]
fn fuzz_linked_list(ops: Vec<usize>) {