summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--tokio/src/sync/mpsc/chan.rs8
-rw-r--r--tokio/src/sync/mutex.rs6
-rw-r--r--tokio/src/sync/semaphore.rs8
-rw-r--r--tokio/src/sync/semaphore_ll.rs943
-rw-r--r--tokio/src/sync/tests/loom_semaphore_ll.rs62
-rw-r--r--tokio/src/sync/tests/semaphore_ll.rs394
6 files changed, 973 insertions, 448 deletions
diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs
index 7a15e8b3..847a0b70 100644
--- a/tokio/src/sync/mpsc/chan.rs
+++ b/tokio/src/sync/mpsc/chan.rs
@@ -408,7 +408,7 @@ impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) {
}
fn drop_permit(&self, permit: &mut Permit) {
- permit.release(&self.0);
+ permit.release(1, &self.0);
}
fn add_permit(&self) {
@@ -425,17 +425,17 @@ impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) {
permit: &mut Permit,
) -> Poll<Result<(), ClosedError>> {
permit
- .poll_acquire(cx, &self.0)
+ .poll_acquire(cx, 1, &self.0)
.map_err(|_| ClosedError::new())
}
fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> {
- permit.try_acquire(&self.0)?;
+ permit.try_acquire(1, &self.0)?;
Ok(())
}
fn forget(&self, permit: &mut Self::Permit) {
- permit.forget()
+ permit.forget(1);
}
fn close(&self) {
diff --git a/tokio/src/sync/mutex.rs b/tokio/src/sync/mutex.rs
index fe891159..48451357 100644
--- a/tokio/src/sync/mutex.rs
+++ b/tokio/src/sync/mutex.rs
@@ -156,7 +156,7 @@ impl<T> Mutex<T> {
lock: self,
permit: semaphore::Permit::new(),
};
- poll_fn(|cx| guard.permit.poll_acquire(cx, &self.s))
+ poll_fn(|cx| guard.permit.poll_acquire(cx, 1, &self.s))
.await
.unwrap_or_else(|_| {
// The semaphore was closed. but, we never explicitly close it, and we have a
@@ -169,7 +169,7 @@ impl<T> Mutex<T> {
/// Try to acquire the lock
pub fn try_lock(&self) -> Result<MutexGuard<'_, T>, TryLockError> {
let mut permit = semaphore::Permit::new();
- match permit.try_acquire(&self.s) {
+ match permit.try_acquire(1, &self.s) {
Ok(_) => Ok(MutexGuard { lock: self, permit }),
Err(_) => Err(TryLockError(())),
}
@@ -178,7 +178,7 @@ impl<T> Mutex<T> {
impl<'a, T> Drop for MutexGuard<'a, T> {
fn drop(&mut self) {
- self.permit.release(&self.lock.s);
+ self.permit.release(1, &self.lock.s);
}
}
diff --git a/tokio/src/sync/semaphore.rs b/tokio/src/sync/semaphore.rs
index 2cfb5d34..13d5cfb2 100644
--- a/tokio/src/sync/semaphore.rs
+++ b/tokio/src/sync/semaphore.rs
@@ -60,7 +60,7 @@ impl Semaphore {
sem: &self,
ll_permit: ll::Permit::new(),
};
- poll_fn(|cx| permit.ll_permit.poll_acquire(cx, &self.ll_sem))
+ poll_fn(|cx| permit.ll_permit.poll_acquire(cx, 1, &self.ll_sem))
.await
.unwrap();
permit
@@ -69,7 +69,7 @@ impl Semaphore {
/// Try to acquire a permit form the semaphore
pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> {
let mut ll_permit = ll::Permit::new();
- match ll_permit.try_acquire(&self.ll_sem) {
+ match ll_permit.try_acquire(1, &self.ll_sem) {
Ok(_) => Ok(SemaphorePermit {
sem: self,
ll_permit,
@@ -84,12 +84,12 @@ impl<'a> SemaphorePermit<'a> {
/// This can be used to reduce the amount of permits available from a
/// semaphore.
pub fn forget(mut self) {
- self.ll_permit.forget();
+ self.ll_permit.forget(1);
}
}
impl<'a> Drop for SemaphorePermit<'_> {
fn drop(&mut self) {
- self.ll_permit.release(&self.sem.ll_sem);
+ self.ll_permit.release(1, &self.sem.ll_sem);
}
}
diff --git a/tokio/src/sync/semaphore_ll.rs b/tokio/src/sync/semaphore_ll.rs
index 0ce85838..baed5f0a 100644
--- a/tokio/src/sync/semaphore_ll.rs
+++ b/tokio/src/sync/semaphore_ll.rs
@@ -10,17 +10,15 @@
//! section. If no permits are available, then acquiring the semaphore returns
//! `Pending`. The task is woken once a permit becomes available.
-use crate::loom::{
- cell::CausalCell,
- future::AtomicWaker,
- sync::atomic::{AtomicPtr, AtomicUsize},
- thread,
-};
+use crate::loom::cell::CausalCell;
+use crate::loom::future::AtomicWaker;
+use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize};
+use crate::loom::thread;
+use std::cmp;
use std::fmt;
use std::ptr::{self, NonNull};
use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release};
-use std::sync::Arc;
use std::task::Poll::{Pending, Ready};
use std::task::{Context, Poll};
use std::usize;
@@ -32,13 +30,13 @@ pub(crate) struct Semaphore {
state: AtomicUsize,
/// waiter queue head pointer.
- head: CausalCell<NonNull<WaiterNode>>,
+ head: CausalCell<NonNull<Waiter>>,
/// Coordinates access to the queue head.
rx_lock: AtomicUsize,
/// Stub waiter node used as part of the MPSC channel algorithm.
- stub: Box<WaiterNode>,
+ stub: Box<Waiter>,
}
/// A semaphore permit
@@ -54,7 +52,7 @@ pub(crate) struct Semaphore {
/// before dropping the permit.
#[derive(Debug)]
pub(crate) struct Permit {
- waiter: Option<Arc<WaiterNode>>,
+ waiter: Option<Box<Waiter>>,
state: PermitState,
}
@@ -64,29 +62,24 @@ pub(crate) struct AcquireError(());
/// Error returned by `Permit::try_acquire`.
#[derive(Debug)]
-pub(crate) struct TryAcquireError {
- pub(crate) kind: ErrorKind,
-}
-
-#[derive(Debug)]
-pub(crate) enum ErrorKind {
+pub(crate) enum TryAcquireError {
Closed,
NoPermits,
}
/// Node used to notify the semaphore waiter when permit is available.
#[derive(Debug)]
-struct WaiterNode {
+struct Waiter {
/// Stores waiter state.
///
- /// See `NodeState` for more details.
+ /// See `WaiterState` for more details.
state: AtomicUsize,
/// Task to wake when a permit is made available.
waker: AtomicWaker,
/// Next pointer in the queue of waiting senders.
- next: AtomicPtr<WaiterNode>,
+ next: AtomicPtr<Waiter>,
}
/// Semaphore state
@@ -103,46 +96,55 @@ struct WaiterNode {
struct SemState(usize);
/// Permit state
-#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+#[derive(Debug, Copy, Clone)]
enum PermitState {
- /// The permit has not been requested.
- Idle,
-
- /// Currently waiting for a permit to be made available and assigned to the
+ /// Currently waiting for permits to be made available and assigned to the
/// waiter.
- Waiting,
+ Waiting(u16),
- /// The permit has been acquired.
- Acquired,
+ /// The number of acquired permits
+ Acquired(u16),
}
-/// Waiter node state
-#[derive(Debug, Copy, Clone, Eq, PartialEq)]
-#[repr(usize)]
-enum NodeState {
- /// Not waiting for a permit and the node is not in the wait queue.
- ///
- /// This is the initial state.
- Idle = 0,
+/// State for an individual waker node
+#[derive(Debug, Copy, Clone)]
+struct WaiterState(usize);
- /// Not waiting for a permit but the node is in the wait queue.
- ///
- /// This happens when the waiter has previously requested a permit, but has
- /// since canceled the request. The node cannot be removed by the waiter, so
- /// this state informs the receiver to skip the node when it pops it from
- /// the wait queue.
- Queued = 1,
+/// Waiter node is in the semaphore queue
+const QUEUED: usize = 0b001;
- /// Waiting for a permit and the node is in the wait queue.
- QueuedWaiting = 2,
+/// Semaphore has been closed, no more permits will be issued.
+const CLOSED: usize = 0b10;
- /// The waiter has been assigned a permit and the node has been removed from
- /// the queue.
- Assigned = 3,
+/// The permit that owns the `Waiter` dropped.
+const DROPPED: usize = 0b100;
- /// The semaphore has been closed. No more permits will be issued.
- Closed = 4,
-}
+/// Represents "one requested permit" in the waiter state
+const PERMIT_ONE: usize = 0b1000;
+
+/// Masks the waiter state to only contain bits tracking number of requested
+/// permits.
+const PERMIT_MASK: usize = usize::MAX - (PERMIT_ONE - 1);
+
+/// How much to shift a permit count to pack it into the waker state
+const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros();
+
+/// Flag differentiating between available permits and waiter pointers.
+///
+/// If we assume pointers are properly aligned, then the least significant bit
+/// will always be zero. So, we use that bit to track if the value represents a
+/// number.
+const NUM_FLAG: usize = 0b01;
+
+/// Signal the semaphore is closed
+const CLOSED_FLAG: usize = 0b10;
+
+/// Maximum number of permits a semaphore can manage
+const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT;
+
+/// When representing "numbers", the state has to be shifted this much (to get
+/// rid of the flag bit).
+const NUM_SHIFT: usize = 2;
// ===== impl Semaphore =====
@@ -153,8 +155,8 @@ impl Semaphore {
///
/// Panics if `permits` is zero.
pub(crate) fn new(permits: usize) -> Semaphore {
- let stub = Box::new(WaiterNode::new());
- let ptr = NonNull::new(&*stub as *const _ as *mut _).unwrap();
+ let stub = Box::new(Waiter::new());
+ let ptr = NonNull::from(&*stub);
// Allocations are aligned
debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0);
@@ -171,32 +173,63 @@ impl Semaphore {
/// Returns the current number of available permits
pub(crate) fn available_permits(&self) -> usize {
- let curr = SemState::load(&self.state, Acquire);
+ let curr = SemState(self.state.load(Acquire));
curr.available_permits()
}
- /// Poll for a permit
- fn poll_permit(
+ /// Try to acquire the requested number of permits, registering the waiter
+ /// if not enough permits are available.
+ fn poll_acquire(
&self,
- mut permit: Option<(&mut Context<'_>, &mut Permit)>,
+ cx: &mut Context<'_>,
+ num_permits: u16,
+ permit: &mut Permit,
) -> Poll<Result<(), AcquireError>> {
+ self.poll_acquire2(num_permits, || {
+ let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new()));
+
+ waiter.waker.register_by_ref(cx.waker());
+
+ Some(NonNull::from(&**waiter))
+ })
+ }
+
+ fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> {
+ match self.poll_acquire2(num_permits, || None) {
+ Poll::Ready(res) => res.map_err(to_try_acquire),
+ Poll::Pending => Err(TryAcquireError::NoPermits),
+ }
+ }
+
+ /// Poll for a permit
+ ///
+ /// Tries to acquire available permits first. If unable to acquire a
+ /// sufficient number of permits, the caller's waiter is pushed onto the
+ /// semaphore's wait queue.
+ fn poll_acquire2<F>(
+ &self,
+ num_permits: u16,
+ mut get_waiter: F,
+ ) -> Poll<Result<(), AcquireError>>
+ where
+ F: FnMut() -> Option<NonNull<Waiter>>,
+ {
+ let num_permits = num_permits as usize;
+
// Load the current state
- let mut curr = SemState::load(&self.state, Acquire);
+ let mut curr = SemState(self.state.load(Acquire));
- // Tracks a *mut WaiterNode representing an Arc clone.
- //
- // This avoids having to bump the ref count unless required.
- let mut maybe_strong: Option<NonNull<WaiterNode>> = None;
+ // Saves a ref to the waiter node
+ let mut maybe_waiter: Option<NonNull<Waiter>> = None;
- macro_rules! undo_strong {
+ /// Used in branches where we attempt to push the waiter into the wait
+ /// queue but fail due to permits becoming available or the wait queue
+ /// transitioning to "closed". In this case, the waiter must be
+ /// transitioned back to the "idle" state.
+ macro_rules! revert_to_idle {
() => {
- if let Some(waiter) = maybe_strong {
- // The waiter was cloned, but never got queued.
- // Before entering `poll_permit`, the waiter was in the
- // `Idle` state. We must transition the node back to the
- // idle state.
- let waiter = unsafe { Arc::from_raw(waiter.as_ptr()) };
- waiter.revert_to_idle();
+ if let Some(waiter) = maybe_waiter {
+ unsafe { waiter.as_ref() }.revert_to_idle();
}
};
}
@@ -205,64 +238,82 @@ impl Semaphore {
let mut next = curr;
if curr.is_closed() {
- undo_strong!();
+ revert_to_idle!();
return Ready(Err(AcquireError::closed()));
}
- if !next.acquire_permit(&self.stub) {
- debug_assert!(curr.waiter().is_some());
+ let acquired = next.acquire_permits(num_permits, &self.stub);
- if maybe_strong.is_none() {
- if let Some((ref mut cx, ref mut permit)) = permit {
- // Get the Sender's waiter node, or initialize one
- let waiter = permit
- .waiter
- .get_or_insert_with(|| Arc::new(WaiterNode::new()));
+ if !acquired {
+ // There are not enough available permits to satisfy the
+ // request. The permit transitions to a waiting state.
+ debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits);
- waiter.register(cx);
-
- if !waiter.to_queued_waiting() {
+ if let Some(waiter) = maybe_waiter.as_ref() {
+ // Safety: the caller owns the waiter.
+ let w = unsafe { waiter.as_ref() };
+ w.set_permits_to_acquire(num_permits - curr.available_permits());
+ } else {
+ // Get the waiter for the permit.
+ if let Some(waiter) = get_waiter() {
+ // Safety: the caller owns the waiter.
+ let w = unsafe { waiter.as_ref() };
+
+ // If there are any currently available permits, the
+ // waiter acquires those immediately and waits for the
+ // remaining permits to become available.
+ if !w.to_queued(num_permits - curr.available_permits()) {
// The node is alrady queued, there is no further work
// to do.
return Pending;
}
- maybe_strong = Some(WaiterNode::into_non_null(waiter.clone()));
+ maybe_waiter = Some(waiter);
} else {
- // If no `waiter`, then the task is not registered and there
- // is no further work to do.
+ // No waiter, this indicates the caller does not wish to
+ // "wait", so there is nothing left to do.
return Pending;
}
}
- next.set_waiter(maybe_strong.unwrap());
+ next.set_waiter(maybe_waiter.unwrap());
}
debug_assert_ne!(curr.0, 0);
debug_assert_ne!(next.0, 0);
- match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
+ match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
Ok(_) => {
- match curr.waiter() {
- Some(prev_waiter) => {
- let waiter = maybe_strong.unwrap();
-
- // Finish pushing
- unsafe {
- prev_waiter.as_ref().next.store(waiter.as_ptr(), Release);
- }
-
- return Pending;
+ if acquired {
+ // Successfully acquire permits **without** queuing the
+ // waiter node. The waiter node is not currently in the
+ // queue.
+ revert_to_idle!();
+ return Ready(Ok(()));
+ } else {
+ // The node is pushed into the queue, the final step is
+ // to set the node's "next" pointer to return the wait
+ // queue into a consistent state.
+
+ let prev_waiter =
+ curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub));
+
+ let waiter = maybe_waiter.unwrap();
+
+ // Link the nodes.
+ //
+ // Safety: the mpsc algorithm guarantees the old tail of
+ // the queue is not removed from the queue during the
+ // push process.
+ unsafe {
+ prev_waiter.as_ref().store_next(waiter);
}
- None => {
- undo_strong!();
- return Ready(Ok(()));
- }
+ return Pending;
}
}
Err(actual) => {
- curr = actual;
+ curr = SemState(actual);
}
}
}
@@ -332,27 +383,13 @@ impl Semaphore {
/// This function is called by `add_permits` after the add lock has been
/// acquired.
fn add_permits_locked2(&self, mut n: usize, closed: bool) {
- while n > 0 || closed {
- let waiter = match self.pop(n, closed) {
- Some(waiter) => waiter,
- None => {
- return;
- }
- };
-
- if waiter.notify(closed) {
- n = n.saturating_sub(1);
- }
+ // If closing the semaphore, we want to drain the entire queue. The
+ // number of permits being assigned doesn't matter.
+ if closed {
+ n = usize::MAX;
}
- }
- /// Pop a waiter
- ///
- /// `rem` represents the remaining number of times the caller will pop. If
- /// there are no more waiters to pop, `rem` is used to set the available
- /// permits.
- fn pop(&self, rem: usize, closed: bool) -> Option<Arc<WaiterNode>> {
- 'outer: loop {
+ 'outer: while n > 0 {
unsafe {
let mut head = self.head.with(|head| *head);
let mut next_ptr = head.as_ref().next.load(Acquire);
@@ -360,12 +397,14 @@ impl Semaphore {
let stub = self.stub();
if head == stub {
+ // The stub node indicates an empty queue. Any remaining
+ // permits get assigned back to the semaphore.
let next = match NonNull::new(next_ptr) {
Some(next) => next,
None => {
// This loop is not part of the standard intrusive mpsc
// channel algorithm. This is where we atomically pop
- // the last task and add `rem` to the remaining capacity.
+ // the last task and add `n` to the remaining capacity.
//
// This modification to the pop algorithm works because,
// at this point, we have not done any work (only done
@@ -384,27 +423,26 @@ impl Semaphore {
loop {
if curr.has_waiter(&self.stub) {
- // Inconsistent
+ // A waiter is being added concurrently.
+ // This is the MPSC queue's "inconsistent"
+ // state and we must loop and try again.
thread::yield_now();
continue 'outer;
}
- // When closing the semaphore, nodes are popped
- // with `rem == 0`. In this case, we are not
- // adding permits, but notifying waiters of the
- // semaphore's closed state.
- if rem == 0 {
+ // If closing, nothing more to do.
+ if closed {
debug_assert!(curr.is_closed(), "state = {:?}", curr);
- return None;
+ return;
}
let mut next = curr;
- next.release_permits(rem, &self.stub);
+ next.release_permits(n, &self.stub);
- match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
- Ok(_) => return None,
+ match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
+ Ok(_) => return,
Err(actual) => {
- curr = actual;
+ curr = SemState(actual);
}
}
}
@@ -416,10 +454,19 @@ impl Semaphore {
next_ptr = next.as_ref().next.load(Acquire);
}
+ // `head` points to a waiter assign permits to the waiter. If
+ // all requested permits are satisfied, then we can continue,
+ // otherwise the node stays in the wait queue.
+ if !head.as_ref().assign_permits(&mut n, closed) {
+ assert_eq!(n, 0);
+ return;
+ }
+
if let Some(next) = NonNull::new(next_ptr) {
self.head.with_mut(|head| *head = next);
- return Some(Arc::from_raw(head.as_ptr()));
+ self.remove_queued(head, closed);
+ continue 'outer;
}
let state = SemState::load(&self.state, Acquire);
@@ -440,7 +487,8 @@ impl Semaphore {
if let Some(next) = NonNull::new(next_ptr) {
self.head.with_mut(|head| *head = next);
- return Some(Arc::from_raw(head.as_ptr()));
+ self.remove_queued(head, closed);
+ continue 'outer;
}
// Inconsistent state, loop
@@ -449,37 +497,89 @@ impl Semaphore {
}
}
+ /// The wait node has had all of its permits assigned and has been removed
+ /// from the wait queue.
+ ///
+ /// Attempt to remove the QUEUED bit from the node. If additional permits
+ /// are concurrently requested, the node must be pushed back into the wait
+ /// queued.
+ fn remove_queued(&self, waiter: NonNull<Waiter>, closed: bool) {
+ let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire));
+
+ loop {
+ if curr.is_dropped() {
+ // The Permit dropped, it is on us to release the memory
+ let _ = unsafe { Box::from_raw(waiter.as_ptr()) };
+ return;
+ }
+
+ // The node is removed from the queue. We attempt to unset the
+ // queued bit, but concurrently the waiter has requested more
+ // permits. When the waiter requested more permits, it saw the
+ // queued bit set so took no further action. This requires us to
+ // push the node back into the queue.
+ if curr.permits_to_acquire() > 0 {
+ // More permits are requested. The waiter must be re-queued
+ unsafe {
+ self.push_waiter(waiter, closed);
+ }
+ return;
+ }
+
+ let mut next = curr;
+ next.unset_queued();
+
+ let w = unsafe { waiter.as_ref() };
+
+ match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
+ Ok(_) => return,
+ Err(actual) => {
+ curr = WaiterState(actual);
+ }
+ }
+ }
+ }
+
unsafe fn push_stub(&self, closed: bool) {
- let stub = self.stub();
+ self.push_waiter(self.stub(), closed);
+ }
+ unsafe fn push_waiter(&self, waiter: NonNull<Waiter>, closed: bool) {
// Set the next pointer. This does not require an atomic operation as
// this node is not accessible. The write will be flushed with the next
// operation
- stub.as_ref().next.store(ptr::null_mut(), Relaxed);
+ waiter.as_ref().next.store(ptr::null_mut(), Relaxed);
// Update the tail to point to the new node. We need to see the previous
// node in order to update the next pointer as well as release `task`
// to any other threads calling `push`.
- let prev = SemState::new_ptr(stub, closed).swap(&self.state, AcqRel);
+ let next = SemState::new_ptr(waiter, closed);
+ let prev = SemState(self.state.swap(next.0, AcqRel));
debug_assert_eq!(closed, prev.is_closed());
- // The stub is only pushed when there are pending tasks. Because of
+ // This function is only called when there are pending tasks. Because of
// this, the state must *always* be in pointer mode.
let prev = prev.waiter().unwrap();
- // We don't want the *existing* pointer to be a stub.
- debug_assert_ne!(prev, stub);
+ // No cycles plz
+ debug_assert_ne!(prev, waiter);
// Release `task` to the consume end.
- prev.as_ref().next.store(stub.as_ptr(), Release);
+ prev.as_ref().next.store(waiter.as_ptr(), Release);
}
- fn stub(&self) -> NonNull<WaiterNode> {
+ fn stub(&self) -> NonNull<Waiter> {
unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) }
}
}
+impl Drop for Semaphore {
+ fn drop(&mut self) {
+ self.close();
+ }
+}
+
impl fmt::Debug for Semaphore {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Semaphore")
@@ -501,15 +601,20 @@ impl Permit {
///
/// The permit begins in the "unacquired" state.
pub(crate) fn new() -> Permit {
+ use PermitState::Acquired;
+
Permit {
waiter: None,
- state: PermitState::Idle,
+ state: Acquired(0),
}
}
/// Returns true if the permit has been acquired
pub(crate) fn is_acquired(&self) -> bool {
- self.state == PermitState::Acquired
+ match self.state {
+ PermitState::Acquired(num) if num > 0 => true,
+ _ => false,
+ }
}
/// Try to acquire the permit. If no permits are available, the current task
@@ -517,70 +622,127 @@ impl Permit {
pub(crate) fn poll_acquire(
&mut self,
cx: &mut Context<'_>,
+ num_permits: u16,
semaphore: &Semaphore,
) -> Poll<Result<(), AcquireError>> {
+ use std::cmp::Ordering::*;
+ use PermitState::*;
+
match self.state {
- PermitState::Idle => {}
- PermitState::Waiting => {
+ Waiting(requested) => {
+ // There must be a waiter
let waiter = self.waiter.as_ref().unwrap();
- if waiter.acquire(cx)? {
- self.state = PermitState::Acquired;
+ match requested.cmp(&num_permits) {
+ Less => {
+ let delta = num_permits - requested;
+
+ // Request additional permits. If the waiter has been
+ // dequeued, it must be re-queued.
+ if !waiter.try_inc_permits_to_acquire(delta as usize) {
+ let waiter = NonNull::from(&**waiter);
+
+ // Ignore the result. The check for
+ // `permits_to_acquire()` will converge the state as
+ // needed
+ let _ = semaphore.poll_acquire2(delta, || Some(waiter))?;
+ }
+
+ self.state = Waiting(num_permits);
+ }
+ Greater => {
+ let delta = requested - num_permits;
+ let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
+
+ semaphore.add_permits(to_release);
+ self.state = Waiting(num_permits);
+ }
+ Equal => {}
+ }
+
+ if waiter.permits_to_acquire()? == 0 {
+ self.state = Acquired(requested);
+ return Ready(Ok(()));
+ }
+
+ waiter.waker.register_by_ref(cx.waker());
+
+ if waiter.permits_to_acquire()? == 0 {
+ self.state = Acquired(requested);
return Ready(Ok(()));
- } else {
- return Pending;
}
- }
- PermitState::Acquired => {
- return Ready(Ok(()));
- }
- }
- match semaphore.poll_permit(Some((cx, self)))? {
- Ready(()) => {
- self.state = PermitState::Acquired;
- Ready(Ok(()))
- }
- Pending => {
- self.state = PermitState::Waiting;
Pending
}
+ Acquired(acquired) => {
+ if acquired >= num_permits {
+ Ready(Ok(()))
+ } else {
+ match semaphore.poll_acquire(cx, num_permits - acquired, self)? {
+ Ready(()) => {
+ self.state = Acquired(num_permits);
+ Ready(Ok(()))
+ }
+ Pending => {
+ self.state = Waiting(num_permits);
+ Pending
+ }
+ }
+ }
+ }
}
}
/// Try to acquire the permit.
- pub(crate) fn try_acquire(&mut self, semaphore: &Semaphore) -> Result<(), TryAcquireError> {
+ pub(crate) fn try_acquire(
+ &mut self,
+ num_permits: u16,
+ semaphore: &Semaphore,
+ ) -> Result<(), TryAcquireError> {
+ use PermitState::*;
+
match self.state {
- PermitState::Idle => {}
- PermitState::Waiting => {
+ Waiting(requested) => {
+ // There must be a waiter
let waiter = self.waiter.as_ref().unwrap();
- if waiter.acquire2().map_err(to_try_acquire)? {
- self.state = PermitState::Acquired;
- return Ok(());
+ if requested > num_permits {
+ let delta = requested - num_permits;
+ let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
+
+ semaphore.add_permits(to_release);
+ self.state = Waiting(num_permits);
+ }
+
+ let res = waiter.permits_to_acquire().map_err(to_try_acquire)?;
+
+ if res == 0 {
+ if requested < num_permits {
+ // Try to acquire the additional permits
+ semaphore.try_acquire(num_permits - requested)?;
+ }
+
+ self.state = Acquired(num_permits);
+ Ok(())
} else {
- return Err(TryAcquireError::no_permits());
+ Err(TryAcquireError::NoPermits)
}
}
- PermitState::Acquired => {
- return Ok(());
- }
- }
+ Acquired(acquired) => {
+ if acquired < num_permits {
+ semaphore.try_acquire(num_permits - acquired)?;
+ self.state = Acquired(num_permits);
+ }
- match semaphore.poll_permit(None).map_err(to_try_acquire)? {
- Ready(()) => {
- self.state = PermitState::Acquired;
Ok(())
}
- Pending => Err(TryAcquireError::no_permits()),
}
}
/// Release a permit back to the semaphore
- pub(crate) fn release(&mut self, semaphore: &Semaphore) {
- if self.forget2() {
- semaphore.add_permits(1);
- }
+ pub(crate) fn release(&mut self, n: u16, semaphore: &Semaphore) {
+ let n = self.forget(n);
+ semaphore.add_permits(n as usize);
}
/// Forget the permit **without** releasing it back to the semaphore.
@@ -590,22 +752,37 @@ impl Permit {
///
/// Repeatedly calling `forget` without associated calls to `add_permit`
/// will result in the semaphore losing all permits.
- pub(crate) fn forget(&mut self) {
- self.forget2();
- }
+ ///
+ /// Will forget **at most** the number of acquired permits. This number is
+ /// returned.
+ pub(crate) fn forget(&mut self, n: u16) -> u16 {
+ use PermitState::*;
- /// Returns `true` if the permit was acquired
- fn forget2(&mut self) -> bool {
match self.state {
- PermitState::Idle => false,
- PermitState::Waiting => {
- let ret = self.waiter.as_ref().unwrap().cancel_interest();
- self.state = PermitState::Idle;
- ret
+ Waiting(requested) => {
+ let n = cmp::min(n, requested);
+
+ // Decrement
+ let acquired = self
+ .waiter
+ .as_ref()
+ .unwrap()
+ .try_dec_permits_to_acquire(n as usize) as u16;
+
+ if n == requested {
+ self.state = Acquired(0);
+ } else if acquired == requested - n {
+ self.state = Waiting(acquired);
+ } else {
+ self.state = Waiting(requested - n);
+ }
+
+ acquired
}
- PermitState::Acquired => {
- self.state = PermitState::Idle;
- true
+ Acquired(acquired) => {
+ let n = cmp::min(n, acquired);
+ self.state = Acquired(acquired - n);
+ n
}
}
}
@@ -617,6 +794,20 @@ impl Default for Permit {
}
}
+impl Drop for Permit {
+ fn drop(&mut self) {
+ if let Some(waiter) = self.waiter.take() {
+ // Set the dropped flag
+ let state = WaiterState(waiter.state.fetch_or(DROPPED, AcqRel));
+
+ if state.is_queued() {
+ // The waiter is stored in the queue. The semaphore will drop it
+ std::mem::forget(waiter);
+ }
+ }
+ }
+}
+
// ===== impl AcquireError ====
impl AcquireError {
@@ -626,7 +817,7 @@ impl AcquireError {
}
fn to_try_acquire(_: AcquireError) -> TryAcquireError {
- TryAcquireError::closed()
+ TryAcquireError::Closed
}
impl fmt::Displa