diff options
-rw-r--r-- | tokio/src/sync/batch_semaphore.rs | 146 | ||||
-rw-r--r-- | tokio/src/sync/tests/loom_semaphore_batch.rs | 44 | ||||
-rw-r--r-- | tokio/src/util/linked_list.rs | 178 |
3 files changed, 135 insertions, 233 deletions
diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs index 3656c109..5d15311d 100644 --- a/tokio/src/sync/batch_semaphore.rs +++ b/tokio/src/sync/batch_semaphore.rs @@ -129,13 +129,8 @@ impl Semaphore { return; } - // Assign permits to the wait queue, returning a list containing all the - // waiters at the back of the queue that received enough permits to wake - // up. - let notified = self.add_permits_locked(added, self.waiters.lock().unwrap()); - - // Once we release the lock, notify all woken waiters. - notify_all(notified); + // Assign permits to the wait queue + self.add_permits_locked(added, self.waiters.lock().unwrap()); } /// Closes the semaphore. This prevents the semaphore from issuing new @@ -144,20 +139,22 @@ impl Semaphore { // semaphore implementation. #[allow(dead_code)] pub(crate) fn close(&self) { - let notified = { - let mut waiters = self.waiters.lock().unwrap(); - // If the semaphore's permits counter has enough permits for an - // unqueued waiter to acquire all the permits it needs immediately, - // it won't touch the wait list. Therefore, we have to set a bit on - // the permit counter as well. However, we must do this while - // holding the lock --- otherwise, if we set the bit and then wait - // to acquire the lock we'll enter an inconsistent state where the - // permit counter is closed, but the wait list is not. - self.permits.fetch_or(Self::CLOSED, Release); - waiters.closed = true; - waiters.queue.take_all() - }; - notify_all(notified) + let mut waiters = self.waiters.lock().unwrap(); + // If the semaphore's permits counter has enough permits for an + // unqueued waiter to acquire all the permits it needs immediately, + // it won't touch the wait list. Therefore, we have to set a bit on + // the permit counter as well. However, we must do this while + // holding the lock --- otherwise, if we set the bit and then wait + // to acquire the lock we'll enter an inconsistent state where the + // permit counter is closed, but the wait list is not. + self.permits.fetch_or(Self::CLOSED, Release); + waiters.closed = true; + while let Some(mut waiter) = waiters.queue.pop_back() { + let waker = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) }; + if let Some(waker) = waker { + waker.wake(); + } + } } pub(crate) fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> { @@ -189,58 +186,60 @@ impl Semaphore { /// Release `rem` permits to the semaphore's wait list, starting from the /// end of the queue. - /// - /// This returns a new `LinkedList` containing all the waiters that received - /// enough permits to be notified. Once the lock on the wait list is - /// released, this list should be drained and the waiters in it notified. /// /// If `rem` exceeds the number of permits needed by the wait list, the /// remainder are assigned back to the semaphore. - fn add_permits_locked( - &self, - mut rem: usize, - mut waiters: MutexGuard<'_, Waitlist>, - ) -> LinkedList<Waiter> { - // Starting from the back of the wait queue, assign each waiter as many - // permits as it needs until we run out of permits to assign. - let mut last = None; - for waiter in waiters.queue.iter().rev() { - // Was the waiter assigned enough permits to wake it? - if !waiter.assign_permits(&mut rem) { - break; + fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) { + let mut wakers: [Option<Waker>; 8] = Default::default(); + let mut lock = Some(waiters); + let mut is_empty = false; + while rem > 0 { + let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap()); + 'inner: for slot in &mut wakers[..] { + // Was the waiter assigned enough permits to wake it? + match waiters.queue.last() { + Some(waiter) => { + if !waiter.assign_permits(&mut rem) { + break 'inner; + } + } + None => { + is_empty = true; + // If we assigned permits to all the waiters in the queue, and there are + // still permits left over, assign them back to the semaphore. + break 'inner; + } + }; + let mut waiter = waiters.queue.pop_back().unwrap(); + *slot = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) }; } - last = Some(NonNull::from(waiter)); - } - // If we assigned permits to all the waiters in the queue, and there are - // still permits left over, assign them back to the semaphore. - if rem > 0 { - let permits = rem << Self::PERMIT_SHIFT; - assert!( - permits < Self::MAX_PERMITS, - "cannot add more than MAX_PERMITS permits ({})", - Self::MAX_PERMITS - ); - let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release); - assert!( - prev + permits <= Self::MAX_PERMITS, - "number of added permits ({}) would overflow MAX_PERMITS ({})", - rem, - Self::MAX_PERMITS - ); - } + if rem > 0 && is_empty { + let permits = rem << Self::PERMIT_SHIFT; + assert!( + permits < Self::MAX_PERMITS, + "cannot add more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release); + assert!( + prev + permits <= Self::MAX_PERMITS, + "number of added permits ({}) would overflow MAX_PERMITS ({})", + rem, + Self::MAX_PERMITS + ); + rem = 0; + } - // Split off the queue at the last waiter that was satisfied, creating a - // new list. Once we release the lock, we'll drain this list and notify - // the waiters in it. - if let Some(waiter) = last { - // Safety: it's only safe to call `split_back` with a pointer to a - // node in the same list as the one we call `split_back` on. Since - // we got the waiter pointer from the list's iterator, this is fine. - unsafe { waiters.queue.split_back(waiter) } - } else { - LinkedList::new() + drop(waiters); // release the lock + + wakers + .iter_mut() + .filter_map(Option::take) + .for_each(Waker::wake); } + + assert_eq!(rem, 0); } fn poll_acquire( @@ -354,18 +353,6 @@ impl fmt::Debug for Semaphore { } } -/// Pop all waiters from `list`, starting at the end of the queue, and notify -/// them. -fn notify_all(mut list: LinkedList<Waiter>) { - while let Some(waiter) = list.pop_back() { - let waker = unsafe { waiter.as_ref().waker.with_mut(|waker| (*waker).take()) }; - - waker - .expect("if a node is in the wait list, it must have a waker") - .wake(); - } -} - impl Waiter { fn new(num_permits: u16) -> Self { Waiter { @@ -471,8 +458,7 @@ impl Drop for Acquire<'_> { let acquired_permits = self.num_permits as usize - self.node.state.load(Acquire); if acquired_permits > 0 { - let notified = self.semaphore.add_permits_locked(acquired_permits, waiters); - notify_all(notified); + self.semaphore.add_permits_locked(acquired_permits, waiters); } } } diff --git a/tokio/src/sync/tests/loom_semaphore_batch.rs b/tokio/src/sync/tests/loom_semaphore_batch.rs index 4c1936c5..76a1bc00 100644 --- a/tokio/src/sync/tests/loom_semaphore_batch.rs +++ b/tokio/src/sync/tests/loom_semaphore_batch.rs @@ -114,6 +114,50 @@ fn concurrent_close() { } #[test] +fn concurrent_cancel() { + async fn poll_and_cancel(semaphore: Arc<Semaphore>) { + let mut acquire1 = Some(semaphore.acquire(1)); + let mut acquire2 = Some(semaphore.acquire(1)); + poll_fn(|cx| { + // poll the acquire future once, and then immediately throw + // it away. this simulates a situation where a future is + // polled and then cancelled, such as by a timeout. + if let Some(acquire) = acquire1.take() { + pin!(acquire); + let _ = acquire.poll(cx); + } + if let Some(acquire) = acquire2.take() { + pin!(acquire); + let _ = acquire.poll(cx); + } + Poll::Ready(()) + }) + .await + } + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(0)); + let t1 = { + let semaphore = semaphore.clone(); + thread::spawn(move || block_on(poll_and_cancel(semaphore))) + }; + let t2 = { + let semaphore = semaphore.clone(); + thread::spawn(move || block_on(poll_and_cancel(semaphore))) + }; + let t3 = { + let semaphore = semaphore.clone(); + thread::spawn(move || block_on(poll_and_cancel(semaphore))) + }; + + t1.join().unwrap(); + semaphore.release(10); + t2.join().unwrap(); + t3.join().unwrap(); + }); +} + +#[test] fn batch() { let mut b = loom::model::Builder::new(); b.preemption_bound = Some(1); diff --git a/tokio/src/util/linked_list.rs b/tokio/src/util/linked_list.rs index 1a488032..aa3ce771 100644 --- a/tokio/src/util/linked_list.rs +++ b/tokio/src/util/linked_list.rs @@ -164,46 +164,6 @@ impl<T: Link> LinkedList<T> { } } -cfg_sync! { - impl<T: Link> LinkedList<T> { - /// Splits this list off at `node`, returning a new list with `node` at its - /// front. - /// - /// If `node` is at the the front of this list, then this list will be empty after - /// splitting. If `node` is the last node in this list, then the returned - /// list will contain only `node`. - /// - /// # Safety - /// - /// The caller **must** ensure that `node` is currently contained by - /// `self` or not contained by any other list. - pub(crate) unsafe fn split_back(&mut self, node: NonNull<T::Target>) -> Self { - let new_tail = T::pointers(node).as_mut().prev.take().map(|prev| { - T::pointers(prev).as_mut().next = None; - prev - }); - if new_tail.is_none() { - self.head = None; - } - let tail = std::mem::replace(&mut self.tail, new_tail); - Self { - head: Some(node), - tail, - } - } - - /// Takes all entries from this list, returning a new list. - /// - /// This list will be left empty. - pub(crate) fn take_all(&mut self) -> Self { - Self { - head: self.head.take(), - tail: self.tail.take(), - } - } - } -} - impl<T: Link> fmt::Debug for LinkedList<T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("LinkedList") @@ -213,49 +173,41 @@ impl<T: Link> fmt::Debug for LinkedList<T> { } } -// ===== impl Iter ===== - -#[cfg(any(feature = "sync", feature = "rt-threaded"))] -pub(crate) struct Iter<'a, T: Link> { - curr: Option<NonNull<T::Target>>, - #[cfg(feature = "sync")] - curr_back: Option<NonNull<T::Target>>, - _p: core::marker::PhantomData<&'a T>, -} - -#[cfg(any(feature = "sync", feature = "rt-threaded"))] -impl<T: Link> LinkedList<T> { - pub(crate) fn iter(&self) -> Iter<'_, T> { - Iter { - curr: self.head, - #[cfg(feature = "sync")] - curr_back: self.tail, - _p: core::marker::PhantomData, +cfg_sync! { + impl<T: Link> LinkedList<T> { + pub(crate) fn last(&self) -> Option<&T::Target> { + let tail = self.tail.as_ref()?; + unsafe { + Some(&*tail.as_ptr()) + } } } } -#[cfg(any(feature = "sync", feature = "rt-threaded"))] -impl<'a, T: Link> Iterator for Iter<'a, T> { - type Item = &'a T::Target; +// ===== impl Iter ===== - fn next(&mut self) -> Option<&'a T::Target> { - let curr = self.curr?; - // safety: the pointer references data contained by the list - self.curr = unsafe { T::pointers(curr).as_ref() }.next; +cfg_rt_threaded! { + pub(crate) struct Iter<'a, T: Link> { + curr: Option<NonNull<T::Target>>, + _p: core::marker::PhantomData<&'a T>, + } - // safety: the value is still owned by the linked list. - Some(unsafe { &*curr.as_ptr() }) + impl<T: Link> LinkedList<T> { + pub(crate) fn iter(&self) -> Iter<'_, T> { + Iter { + curr: self.head, + _p: core::marker::PhantomData, + } + } } -} -cfg_sync! { - impl<'a, T: Link> DoubleEndedIterator for Iter<'a, T> { - fn next_back(&mut self) -> Option<&'a T::Target> { - let curr = self.curr_back?; + impl<'a, T: Link> Iterator for Iter<'a, T> { + type Item = &'a T::Target; + fn next(&mut self) -> Option<&'a T::Target> { + let curr = self.curr?; // safety: the pointer references data contained by the list - self.curr_back = unsafe { T::pointers(curr).as_ref() }.prev; + self.curr = unsafe { T::pointers(curr).as_ref() }.next; // safety: the value is still owned by the linked list. Some(unsafe { &*curr.as_ptr() }) @@ -564,86 +516,6 @@ mod tests { assert!(i.next().is_none()); } - #[test] - fn split_back() { - let a = entry(1); - let b = entry(2); - let c = entry(3); - let d = entry(4); - - { - let mut list1 = LinkedList::<&Entry>::new(); - - push_all( - &mut list1, - &[a.as_ref(), b.as_ref(), c.as_ref(), d.as_ref()], - ); - let mut list2 = unsafe { list1.split_back(ptr(&a)) }; - - assert_eq!([2, 3, 4].to_vec(), collect_list(&mut list1)); - assert_eq!([1].to_vec(), collect_list(&mut list2)); - } - - { - let mut list1 = LinkedList::<&Entry>::new(); - - push_all( - &mut list1, - &[a.as_ref(), b.as_ref(), c.as_ref(), d.as_ref()], - ); - let mut list2 = unsafe { list1.split_back(ptr(&b)) }; - - assert_eq!([3, 4].to_vec(), collect_list(&mut list1)); - assert_eq!([1, 2].to_vec(), collect_list(&mut list2)); - } - - { - let mut list1 = LinkedList::<&Entry>::new(); - - push_all( - &mut list1, - &[a.as_ref(), b.as_ref(), c.as_ref(), d.as_ref()], - ); - let mut list2 = unsafe { list1.split_back(ptr(&c)) }; - - assert_eq!([4].to_vec(), collect_list(&mut list1)); - assert_eq!([1, 2, 3].to_vec(), collect_list(&mut list2)); - } - - { - let mut list1 = LinkedList::<&Entry>::new(); - - push_all( - &mut list1, - &[a.as_ref(), b.as_ref(), c.as_ref(), d.as_ref()], - ); - let mut list2 = unsafe { list1.split_back(ptr(&d)) }; - - assert_eq!(Vec::<i32>::new(), collect_list(&mut list1)); - assert_eq!([1, 2, 3, 4].to_vec(), collect_list(&mut list2)); - } - } - - #[test] - fn take_all() { - let mut list1 = LinkedList::<&Entry>::new(); - let a = entry(1); - let b = entry(2); - - list1.push_front(a.as_ref()); - list1.push_front(b.as_ref()); - - assert!(!list1.is_empty()); - - let mut list2 = list1.take_all(); - - assert!(list1.is_empty()); - assert!(!list2.is_empty()); - - assert_eq!(Vec::<i32>::new(), collect_list(&mut list1)); - assert_eq!([1, 2].to_vec(), collect_list(&mut list2)); - } - proptest::proptest! { #[test] fn fuzz_linked_list(ops: Vec<usize>) { |