diff options
-rw-r--r-- | tokio/src/sync/mpsc/chan.rs | 8 | ||||
-rw-r--r-- | tokio/src/sync/mutex.rs | 6 | ||||
-rw-r--r-- | tokio/src/sync/semaphore.rs | 8 | ||||
-rw-r--r-- | tokio/src/sync/semaphore_ll.rs | 943 | ||||
-rw-r--r-- | tokio/src/sync/tests/loom_semaphore_ll.rs | 62 | ||||
-rw-r--r-- | tokio/src/sync/tests/semaphore_ll.rs | 394 |
6 files changed, 973 insertions, 448 deletions
diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index 7a15e8b3..847a0b70 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -408,7 +408,7 @@ impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) { } fn drop_permit(&self, permit: &mut Permit) { - permit.release(&self.0); + permit.release(1, &self.0); } fn add_permit(&self) { @@ -425,17 +425,17 @@ impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) { permit: &mut Permit, ) -> Poll<Result<(), ClosedError>> { permit - .poll_acquire(cx, &self.0) + .poll_acquire(cx, 1, &self.0) .map_err(|_| ClosedError::new()) } fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> { - permit.try_acquire(&self.0)?; + permit.try_acquire(1, &self.0)?; Ok(()) } fn forget(&self, permit: &mut Self::Permit) { - permit.forget() + permit.forget(1); } fn close(&self) { diff --git a/tokio/src/sync/mutex.rs b/tokio/src/sync/mutex.rs index fe891159..48451357 100644 --- a/tokio/src/sync/mutex.rs +++ b/tokio/src/sync/mutex.rs @@ -156,7 +156,7 @@ impl<T> Mutex<T> { lock: self, permit: semaphore::Permit::new(), }; - poll_fn(|cx| guard.permit.poll_acquire(cx, &self.s)) + poll_fn(|cx| guard.permit.poll_acquire(cx, 1, &self.s)) .await .unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and we have a @@ -169,7 +169,7 @@ impl<T> Mutex<T> { /// Try to acquire the lock pub fn try_lock(&self) -> Result<MutexGuard<'_, T>, TryLockError> { let mut permit = semaphore::Permit::new(); - match permit.try_acquire(&self.s) { + match permit.try_acquire(1, &self.s) { Ok(_) => Ok(MutexGuard { lock: self, permit }), Err(_) => Err(TryLockError(())), } @@ -178,7 +178,7 @@ impl<T> Mutex<T> { impl<'a, T> Drop for MutexGuard<'a, T> { fn drop(&mut self) { - self.permit.release(&self.lock.s); + self.permit.release(1, &self.lock.s); } } diff --git a/tokio/src/sync/semaphore.rs b/tokio/src/sync/semaphore.rs index 2cfb5d34..13d5cfb2 100644 --- a/tokio/src/sync/semaphore.rs +++ b/tokio/src/sync/semaphore.rs @@ -60,7 +60,7 @@ impl Semaphore { sem: &self, ll_permit: ll::Permit::new(), }; - poll_fn(|cx| permit.ll_permit.poll_acquire(cx, &self.ll_sem)) + poll_fn(|cx| permit.ll_permit.poll_acquire(cx, 1, &self.ll_sem)) .await .unwrap(); permit @@ -69,7 +69,7 @@ impl Semaphore { /// Try to acquire a permit form the semaphore pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> { let mut ll_permit = ll::Permit::new(); - match ll_permit.try_acquire(&self.ll_sem) { + match ll_permit.try_acquire(1, &self.ll_sem) { Ok(_) => Ok(SemaphorePermit { sem: self, ll_permit, @@ -84,12 +84,12 @@ impl<'a> SemaphorePermit<'a> { /// This can be used to reduce the amount of permits available from a /// semaphore. pub fn forget(mut self) { - self.ll_permit.forget(); + self.ll_permit.forget(1); } } impl<'a> Drop for SemaphorePermit<'_> { fn drop(&mut self) { - self.ll_permit.release(&self.sem.ll_sem); + self.ll_permit.release(1, &self.sem.ll_sem); } } diff --git a/tokio/src/sync/semaphore_ll.rs b/tokio/src/sync/semaphore_ll.rs index 0ce85838..baed5f0a 100644 --- a/tokio/src/sync/semaphore_ll.rs +++ b/tokio/src/sync/semaphore_ll.rs @@ -10,17 +10,15 @@ //! section. If no permits are available, then acquiring the semaphore returns //! `Pending`. The task is woken once a permit becomes available. -use crate::loom::{ - cell::CausalCell, - future::AtomicWaker, - sync::atomic::{AtomicPtr, AtomicUsize}, - thread, -}; +use crate::loom::cell::CausalCell; +use crate::loom::future::AtomicWaker; +use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; +use crate::loom::thread; +use std::cmp; use std::fmt; use std::ptr::{self, NonNull}; use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release}; -use std::sync::Arc; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; use std::usize; @@ -32,13 +30,13 @@ pub(crate) struct Semaphore { state: AtomicUsize, /// waiter queue head pointer. - head: CausalCell<NonNull<WaiterNode>>, + head: CausalCell<NonNull<Waiter>>, /// Coordinates access to the queue head. rx_lock: AtomicUsize, /// Stub waiter node used as part of the MPSC channel algorithm. - stub: Box<WaiterNode>, + stub: Box<Waiter>, } /// A semaphore permit @@ -54,7 +52,7 @@ pub(crate) struct Semaphore { /// before dropping the permit. #[derive(Debug)] pub(crate) struct Permit { - waiter: Option<Arc<WaiterNode>>, + waiter: Option<Box<Waiter>>, state: PermitState, } @@ -64,29 +62,24 @@ pub(crate) struct AcquireError(()); /// Error returned by `Permit::try_acquire`. #[derive(Debug)] -pub(crate) struct TryAcquireError { - pub(crate) kind: ErrorKind, -} - -#[derive(Debug)] -pub(crate) enum ErrorKind { +pub(crate) enum TryAcquireError { Closed, NoPermits, } /// Node used to notify the semaphore waiter when permit is available. #[derive(Debug)] -struct WaiterNode { +struct Waiter { /// Stores waiter state. /// - /// See `NodeState` for more details. + /// See `WaiterState` for more details. state: AtomicUsize, /// Task to wake when a permit is made available. waker: AtomicWaker, /// Next pointer in the queue of waiting senders. - next: AtomicPtr<WaiterNode>, + next: AtomicPtr<Waiter>, } /// Semaphore state @@ -103,46 +96,55 @@ struct WaiterNode { struct SemState(usize); /// Permit state -#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[derive(Debug, Copy, Clone)] enum PermitState { - /// The permit has not been requested. - Idle, - - /// Currently waiting for a permit to be made available and assigned to the + /// Currently waiting for permits to be made available and assigned to the /// waiter. - Waiting, + Waiting(u16), - /// The permit has been acquired. - Acquired, + /// The number of acquired permits + Acquired(u16), } -/// Waiter node state -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -#[repr(usize)] -enum NodeState { - /// Not waiting for a permit and the node is not in the wait queue. - /// - /// This is the initial state. - Idle = 0, +/// State for an individual waker node +#[derive(Debug, Copy, Clone)] +struct WaiterState(usize); - /// Not waiting for a permit but the node is in the wait queue. - /// - /// This happens when the waiter has previously requested a permit, but has - /// since canceled the request. The node cannot be removed by the waiter, so - /// this state informs the receiver to skip the node when it pops it from - /// the wait queue. - Queued = 1, +/// Waiter node is in the semaphore queue +const QUEUED: usize = 0b001; - /// Waiting for a permit and the node is in the wait queue. - QueuedWaiting = 2, +/// Semaphore has been closed, no more permits will be issued. +const CLOSED: usize = 0b10; - /// The waiter has been assigned a permit and the node has been removed from - /// the queue. - Assigned = 3, +/// The permit that owns the `Waiter` dropped. +const DROPPED: usize = 0b100; - /// The semaphore has been closed. No more permits will be issued. - Closed = 4, -} +/// Represents "one requested permit" in the waiter state +const PERMIT_ONE: usize = 0b1000; + +/// Masks the waiter state to only contain bits tracking number of requested +/// permits. +const PERMIT_MASK: usize = usize::MAX - (PERMIT_ONE - 1); + +/// How much to shift a permit count to pack it into the waker state +const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros(); + +/// Flag differentiating between available permits and waiter pointers. +/// +/// If we assume pointers are properly aligned, then the least significant bit +/// will always be zero. So, we use that bit to track if the value represents a +/// number. +const NUM_FLAG: usize = 0b01; + +/// Signal the semaphore is closed +const CLOSED_FLAG: usize = 0b10; + +/// Maximum number of permits a semaphore can manage +const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT; + +/// When representing "numbers", the state has to be shifted this much (to get +/// rid of the flag bit). +const NUM_SHIFT: usize = 2; // ===== impl Semaphore ===== @@ -153,8 +155,8 @@ impl Semaphore { /// /// Panics if `permits` is zero. pub(crate) fn new(permits: usize) -> Semaphore { - let stub = Box::new(WaiterNode::new()); - let ptr = NonNull::new(&*stub as *const _ as *mut _).unwrap(); + let stub = Box::new(Waiter::new()); + let ptr = NonNull::from(&*stub); // Allocations are aligned debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0); @@ -171,32 +173,63 @@ impl Semaphore { /// Returns the current number of available permits pub(crate) fn available_permits(&self) -> usize { - let curr = SemState::load(&self.state, Acquire); + let curr = SemState(self.state.load(Acquire)); curr.available_permits() } - /// Poll for a permit - fn poll_permit( + /// Try to acquire the requested number of permits, registering the waiter + /// if not enough permits are available. + fn poll_acquire( &self, - mut permit: Option<(&mut Context<'_>, &mut Permit)>, + cx: &mut Context<'_>, + num_permits: u16, + permit: &mut Permit, ) -> Poll<Result<(), AcquireError>> { + self.poll_acquire2(num_permits, || { + let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new())); + + waiter.waker.register_by_ref(cx.waker()); + + Some(NonNull::from(&**waiter)) + }) + } + + fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> { + match self.poll_acquire2(num_permits, || None) { + Poll::Ready(res) => res.map_err(to_try_acquire), + Poll::Pending => Err(TryAcquireError::NoPermits), + } + } + + /// Poll for a permit + /// + /// Tries to acquire available permits first. If unable to acquire a + /// sufficient number of permits, the caller's waiter is pushed onto the + /// semaphore's wait queue. + fn poll_acquire2<F>( + &self, + num_permits: u16, + mut get_waiter: F, + ) -> Poll<Result<(), AcquireError>> + where + F: FnMut() -> Option<NonNull<Waiter>>, + { + let num_permits = num_permits as usize; + // Load the current state - let mut curr = SemState::load(&self.state, Acquire); + let mut curr = SemState(self.state.load(Acquire)); - // Tracks a *mut WaiterNode representing an Arc clone. - // - // This avoids having to bump the ref count unless required. - let mut maybe_strong: Option<NonNull<WaiterNode>> = None; + // Saves a ref to the waiter node + let mut maybe_waiter: Option<NonNull<Waiter>> = None; - macro_rules! undo_strong { + /// Used in branches where we attempt to push the waiter into the wait + /// queue but fail due to permits becoming available or the wait queue + /// transitioning to "closed". In this case, the waiter must be + /// transitioned back to the "idle" state. + macro_rules! revert_to_idle { () => { - if let Some(waiter) = maybe_strong { - // The waiter was cloned, but never got queued. - // Before entering `poll_permit`, the waiter was in the - // `Idle` state. We must transition the node back to the - // idle state. - let waiter = unsafe { Arc::from_raw(waiter.as_ptr()) }; - waiter.revert_to_idle(); + if let Some(waiter) = maybe_waiter { + unsafe { waiter.as_ref() }.revert_to_idle(); } }; } @@ -205,64 +238,82 @@ impl Semaphore { let mut next = curr; if curr.is_closed() { - undo_strong!(); + revert_to_idle!(); return Ready(Err(AcquireError::closed())); } - if !next.acquire_permit(&self.stub) { - debug_assert!(curr.waiter().is_some()); + let acquired = next.acquire_permits(num_permits, &self.stub); - if maybe_strong.is_none() { - if let Some((ref mut cx, ref mut permit)) = permit { - // Get the Sender's waiter node, or initialize one - let waiter = permit - .waiter - .get_or_insert_with(|| Arc::new(WaiterNode::new())); + if !acquired { + // There are not enough available permits to satisfy the + // request. The permit transitions to a waiting state. + debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits); - waiter.register(cx); - - if !waiter.to_queued_waiting() { + if let Some(waiter) = maybe_waiter.as_ref() { + // Safety: the caller owns the waiter. + let w = unsafe { waiter.as_ref() }; + w.set_permits_to_acquire(num_permits - curr.available_permits()); + } else { + // Get the waiter for the permit. + if let Some(waiter) = get_waiter() { + // Safety: the caller owns the waiter. + let w = unsafe { waiter.as_ref() }; + + // If there are any currently available permits, the + // waiter acquires those immediately and waits for the + // remaining permits to become available. + if !w.to_queued(num_permits - curr.available_permits()) { // The node is alrady queued, there is no further work // to do. return Pending; } - maybe_strong = Some(WaiterNode::into_non_null(waiter.clone())); + maybe_waiter = Some(waiter); } else { - // If no `waiter`, then the task is not registered and there - // is no further work to do. + // No waiter, this indicates the caller does not wish to + // "wait", so there is nothing left to do. return Pending; } } - next.set_waiter(maybe_strong.unwrap()); + next.set_waiter(maybe_waiter.unwrap()); } debug_assert_ne!(curr.0, 0); debug_assert_ne!(next.0, 0); - match next.compare_exchange(&self.state, curr, AcqRel, Acquire) { + match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { Ok(_) => { - match curr.waiter() { - Some(prev_waiter) => { - let waiter = maybe_strong.unwrap(); - - // Finish pushing - unsafe { - prev_waiter.as_ref().next.store(waiter.as_ptr(), Release); - } - - return Pending; + if acquired { + // Successfully acquire permits **without** queuing the + // waiter node. The waiter node is not currently in the + // queue. + revert_to_idle!(); + return Ready(Ok(())); + } else { + // The node is pushed into the queue, the final step is + // to set the node's "next" pointer to return the wait + // queue into a consistent state. + + let prev_waiter = + curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub)); + + let waiter = maybe_waiter.unwrap(); + + // Link the nodes. + // + // Safety: the mpsc algorithm guarantees the old tail of + // the queue is not removed from the queue during the + // push process. + unsafe { + prev_waiter.as_ref().store_next(waiter); } - None => { - undo_strong!(); - return Ready(Ok(())); - } + return Pending; } } Err(actual) => { - curr = actual; + curr = SemState(actual); } } } @@ -332,27 +383,13 @@ impl Semaphore { /// This function is called by `add_permits` after the add lock has been /// acquired. fn add_permits_locked2(&self, mut n: usize, closed: bool) { - while n > 0 || closed { - let waiter = match self.pop(n, closed) { - Some(waiter) => waiter, - None => { - return; - } - }; - - if waiter.notify(closed) { - n = n.saturating_sub(1); - } + // If closing the semaphore, we want to drain the entire queue. The + // number of permits being assigned doesn't matter. + if closed { + n = usize::MAX; } - } - /// Pop a waiter - /// - /// `rem` represents the remaining number of times the caller will pop. If - /// there are no more waiters to pop, `rem` is used to set the available - /// permits. - fn pop(&self, rem: usize, closed: bool) -> Option<Arc<WaiterNode>> { - 'outer: loop { + 'outer: while n > 0 { unsafe { let mut head = self.head.with(|head| *head); let mut next_ptr = head.as_ref().next.load(Acquire); @@ -360,12 +397,14 @@ impl Semaphore { let stub = self.stub(); if head == stub { + // The stub node indicates an empty queue. Any remaining + // permits get assigned back to the semaphore. let next = match NonNull::new(next_ptr) { Some(next) => next, None => { // This loop is not part of the standard intrusive mpsc // channel algorithm. This is where we atomically pop - // the last task and add `rem` to the remaining capacity. + // the last task and add `n` to the remaining capacity. // // This modification to the pop algorithm works because, // at this point, we have not done any work (only done @@ -384,27 +423,26 @@ impl Semaphore { loop { if curr.has_waiter(&self.stub) { - // Inconsistent + // A waiter is being added concurrently. + // This is the MPSC queue's "inconsistent" + // state and we must loop and try again. thread::yield_now(); continue 'outer; } - // When closing the semaphore, nodes are popped - // with `rem == 0`. In this case, we are not - // adding permits, but notifying waiters of the - // semaphore's closed state. - if rem == 0 { + // If closing, nothing more to do. + if closed { debug_assert!(curr.is_closed(), "state = {:?}", curr); - return None; + return; } let mut next = curr; - next.release_permits(rem, &self.stub); + next.release_permits(n, &self.stub); - match next.compare_exchange(&self.state, curr, AcqRel, Acquire) { - Ok(_) => return None, + match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { + Ok(_) => return, Err(actual) => { - curr = actual; + curr = SemState(actual); } } } @@ -416,10 +454,19 @@ impl Semaphore { next_ptr = next.as_ref().next.load(Acquire); } + // `head` points to a waiter assign permits to the waiter. If + // all requested permits are satisfied, then we can continue, + // otherwise the node stays in the wait queue. + if !head.as_ref().assign_permits(&mut n, closed) { + assert_eq!(n, 0); + return; + } + if let Some(next) = NonNull::new(next_ptr) { self.head.with_mut(|head| *head = next); - return Some(Arc::from_raw(head.as_ptr())); + self.remove_queued(head, closed); + continue 'outer; } let state = SemState::load(&self.state, Acquire); @@ -440,7 +487,8 @@ impl Semaphore { if let Some(next) = NonNull::new(next_ptr) { self.head.with_mut(|head| *head = next); - return Some(Arc::from_raw(head.as_ptr())); + self.remove_queued(head, closed); + continue 'outer; } // Inconsistent state, loop @@ -449,37 +497,89 @@ impl Semaphore { } } + /// The wait node has had all of its permits assigned and has been removed + /// from the wait queue. + /// + /// Attempt to remove the QUEUED bit from the node. If additional permits + /// are concurrently requested, the node must be pushed back into the wait + /// queued. + fn remove_queued(&self, waiter: NonNull<Waiter>, closed: bool) { + let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire)); + + loop { + if curr.is_dropped() { + // The Permit dropped, it is on us to release the memory + let _ = unsafe { Box::from_raw(waiter.as_ptr()) }; + return; + } + + // The node is removed from the queue. We attempt to unset the + // queued bit, but concurrently the waiter has requested more + // permits. When the waiter requested more permits, it saw the + // queued bit set so took no further action. This requires us to + // push the node back into the queue. + if curr.permits_to_acquire() > 0 { + // More permits are requested. The waiter must be re-queued + unsafe { + self.push_waiter(waiter, closed); + } + return; + } + + let mut next = curr; + next.unset_queued(); + + let w = unsafe { waiter.as_ref() }; + + match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { + Ok(_) => return, + Err(actual) => { + curr = WaiterState(actual); + } + } + } + } + unsafe fn push_stub(&self, closed: bool) { - let stub = self.stub(); + self.push_waiter(self.stub(), closed); + } + unsafe fn push_waiter(&self, waiter: NonNull<Waiter>, closed: bool) { // Set the next pointer. This does not require an atomic operation as // this node is not accessible. The write will be flushed with the next // operation - stub.as_ref().next.store(ptr::null_mut(), Relaxed); + waiter.as_ref().next.store(ptr::null_mut(), Relaxed); // Update the tail to point to the new node. We need to see the previous // node in order to update the next pointer as well as release `task` // to any other threads calling `push`. - let prev = SemState::new_ptr(stub, closed).swap(&self.state, AcqRel); + let next = SemState::new_ptr(waiter, closed); + let prev = SemState(self.state.swap(next.0, AcqRel)); debug_assert_eq!(closed, prev.is_closed()); - // The stub is only pushed when there are pending tasks. Because of + // This function is only called when there are pending tasks. Because of // this, the state must *always* be in pointer mode. let prev = prev.waiter().unwrap(); - // We don't want the *existing* pointer to be a stub. - debug_assert_ne!(prev, stub); + // No cycles plz + debug_assert_ne!(prev, waiter); // Release `task` to the consume end. - prev.as_ref().next.store(stub.as_ptr(), Release); + prev.as_ref().next.store(waiter.as_ptr(), Release); } - fn stub(&self) -> NonNull<WaiterNode> { + fn stub(&self) -> NonNull<Waiter> { unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) } } } +impl Drop for Semaphore { + fn drop(&mut self) { + self.close(); + } +} + impl fmt::Debug for Semaphore { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Semaphore") @@ -501,15 +601,20 @@ impl Permit { /// /// The permit begins in the "unacquired" state. pub(crate) fn new() -> Permit { + use PermitState::Acquired; + Permit { waiter: None, - state: PermitState::Idle, + state: Acquired(0), } } /// Returns true if the permit has been acquired pub(crate) fn is_acquired(&self) -> bool { - self.state == PermitState::Acquired + match self.state { + PermitState::Acquired(num) if num > 0 => true, + _ => false, + } } /// Try to acquire the permit. If no permits are available, the current task @@ -517,70 +622,127 @@ impl Permit { pub(crate) fn poll_acquire( &mut self, cx: &mut Context<'_>, + num_permits: u16, semaphore: &Semaphore, ) -> Poll<Result<(), AcquireError>> { + use std::cmp::Ordering::*; + use PermitState::*; + match self.state { - PermitState::Idle => {} - PermitState::Waiting => { + Waiting(requested) => { + // There must be a waiter let waiter = self.waiter.as_ref().unwrap(); - if waiter.acquire(cx)? { - self.state = PermitState::Acquired; + match requested.cmp(&num_permits) { + Less => { + let delta = num_permits - requested; + + // Request additional permits. If the waiter has been + // dequeued, it must be re-queued. + if !waiter.try_inc_permits_to_acquire(delta as usize) { + let waiter = NonNull::from(&**waiter); + + // Ignore the result. The check for + // `permits_to_acquire()` will converge the state as + // needed + let _ = semaphore.poll_acquire2(delta, || Some(waiter))?; + } + + self.state = Waiting(num_permits); + } + Greater => { + let delta = requested - num_permits; + let to_release = waiter.try_dec_permits_to_acquire(delta as usize); + + semaphore.add_permits(to_release); + self.state = Waiting(num_permits); + } + Equal => {} + } + + if waiter.permits_to_acquire()? == 0 { + self.state = Acquired(requested); + return Ready(Ok(())); + } + + waiter.waker.register_by_ref(cx.waker()); + + if waiter.permits_to_acquire()? == 0 { + self.state = Acquired(requested); return Ready(Ok(())); - } else { - return Pending; } - } - PermitState::Acquired => { - return Ready(Ok(())); - } - } - match semaphore.poll_permit(Some((cx, self)))? { - Ready(()) => { - self.state = PermitState::Acquired; - Ready(Ok(())) - } - Pending => { - self.state = PermitState::Waiting; Pending } + Acquired(acquired) => { + if acquired >= num_permits { + Ready(Ok(())) + } else { + match semaphore.poll_acquire(cx, num_permits - acquired, self)? { + Ready(()) => { + self.state = Acquired(num_permits); + Ready(Ok(())) + } + Pending => { + self.state = Waiting(num_permits); + Pending + } + } + } + } } } /// Try to acquire the permit. - pub(crate) fn try_acquire(&mut self, semaphore: &Semaphore) -> Result<(), TryAcquireError> { + pub(crate) fn try_acquire( + &mut self, + num_permits: u16, + semaphore: &Semaphore, + ) -> Result<(), TryAcquireError> { + use PermitState::*; + match self.state { - PermitState::Idle => {} - PermitState::Waiting => { + Waiting(requested) => { + // There must be a waiter let waiter = self.waiter.as_ref().unwrap(); - if waiter.acquire2().map_err(to_try_acquire)? { - self.state = PermitState::Acquired; - return Ok(()); + if requested > num_permits { + let delta = requested - num_permits; + let to_release = waiter.try_dec_permits_to_acquire(delta as usize); + + semaphore.add_permits(to_release); + self.state = Waiting(num_permits); + } + + let res = waiter.permits_to_acquire().map_err(to_try_acquire)?; + + if res == 0 { + if requested < num_permits { + // Try to acquire the additional permits + semaphore.try_acquire(num_permits - requested)?; + } + + self.state = Acquired(num_permits); + Ok(()) } else { - return Err(TryAcquireError::no_permits()); + Err(TryAcquireError::NoPermits) } } - PermitState::Acquired => { - return Ok(()); - } - } + Acquired(acquired) => { + if acquired < num_permits { + semaphore.try_acquire(num_permits - acquired)?; + self.state = Acquired(num_permits); + } - match semaphore.poll_permit(None).map_err(to_try_acquire)? { - Ready(()) => { - self.state = PermitState::Acquired; Ok(()) } - Pending => Err(TryAcquireError::no_permits()), } } /// Release a permit back to the semaphore - pub(crate) fn release(&mut self, semaphore: &Semaphore) { - if self.forget2() { - semaphore.add_permits(1); - } + pub(crate) fn release(&mut self, n: u16, semaphore: &Semaphore) { + let n = self.forget(n); + semaphore.add_permits(n as usize); } /// Forget the permit **without** releasing it back to the semaphore. @@ -590,22 +752,37 @@ impl Permit { /// /// Repeatedly calling `forget` without associated calls to `add_permit` /// will result in the semaphore losing all permits. - pub(crate) fn forget(&mut self) { - self.forget2(); - } + /// + /// Will forget **at most** the number of acquired permits. This number is + /// returned. + pub(crate) fn forget(&mut self, n: u16) -> u16 { + use PermitState::*; - /// Returns `true` if the permit was acquired - fn forget2(&mut self) -> bool { match self.state { - PermitState::Idle => false, - PermitState::Waiting => { - let ret = self.waiter.as_ref().unwrap().cancel_interest(); - self.state = PermitState::Idle; - ret + Waiting(requested) => { + let n = cmp::min(n, requested); + + // Decrement + let acquired = self + .waiter + .as_ref() + .unwrap() + .try_dec_permits_to_acquire(n as usize) as u16; + + if n == requested { + self.state = Acquired(0); + } else if acquired == requested - n { + self.state = Waiting(acquired); + } else { + self.state = Waiting(requested - n); + } + + acquired } - PermitState::Acquired => { - self.state = PermitState::Idle; - true + Acquired(acquired) => { + let n = cmp::min(n, acquired); + self.state = Acquired(acquired - n); + n } } } @@ -617,6 +794,20 @@ impl Default for Permit { } } +impl Drop for Permit { + fn drop(&mut self) { + if let Some(waiter) = self.waiter.take() { + // Set the dropped flag + let state = WaiterState(waiter.state.fetch_or(DROPPED, AcqRel)); + + if state.is_queued() { + // The waiter is stored in the queue. The semaphore will drop it + std::mem::forget(waiter); + } + } + } +} + // ===== impl AcquireError ==== impl AcquireError { @@ -626,7 +817,7 @@ impl AcquireError { } fn to_try_acquire(_: AcquireError) -> TryAcquireError { - TryAcquireError::closed() + TryAcquireError::Closed } impl fmt::Displa |