From cf025ba45f68934ae2138bb75ee2a5ee50506d1b Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Thu, 24 Sep 2020 17:26:38 -0700 Subject: sync: support mpsc send with `&self` (#2861) Updates the mpsc channel to use the intrusive waker based sempahore. This enables using `Sender` with `&self`. Instead of using `Sender::poll_ready` to ensure capacity and updating the `Sender` state, `async fn Sender::reserve()` is added. This function returns a `Permit` value representing the reserved capacity. Fixes: #2637 Refs: #2718 (intrusive waiters) --- tokio-test/src/io.rs | 4 +- tokio/src/signal/unix.rs | 30 +- tokio/src/stream/mod.rs | 4 +- tokio/src/stream/stream_map.rs | 4 +- tokio/src/sync/batch_semaphore.rs | 10 +- tokio/src/sync/mod.rs | 9 +- tokio/src/sync/mpsc/bounded.rs | 338 ++++---- tokio/src/sync/mpsc/chan.rs | 268 +------ tokio/src/sync/mpsc/error.rs | 20 - tokio/src/sync/mpsc/mod.rs | 2 +- tokio/src/sync/mpsc/unbounded.rs | 39 +- tokio/src/sync/semaphore_ll.rs | 1221 ----------------------------- tokio/src/sync/tests/loom_mpsc.rs | 14 +- tokio/src/sync/tests/loom_semaphore_ll.rs | 192 ----- tokio/src/sync/tests/mod.rs | 2 - tokio/src/sync/tests/semaphore_ll.rs | 470 ----------- tokio/src/util/linked_list.rs | 20 +- tokio/tests/rt_threaded.rs | 10 +- tokio/tests/sync_mpsc.rs | 363 ++++----- 19 files changed, 459 insertions(+), 2561 deletions(-) delete mode 100644 tokio/src/sync/semaphore_ll.rs delete mode 100644 tokio/src/sync/tests/loom_semaphore_ll.rs delete mode 100644 tokio/src/sync/tests/semaphore_ll.rs diff --git a/tokio-test/src/io.rs b/tokio-test/src/io.rs index 4f0b5897..b91ddc34 100644 --- a/tokio-test/src/io.rs +++ b/tokio-test/src/io.rs @@ -200,7 +200,9 @@ impl Inner { } fn poll_action(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.rx.poll_recv(cx) + use futures_core::stream::Stream; + + Pin::new(&mut self.rx).poll_next(cx) } fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> { diff --git a/tokio/src/signal/unix.rs b/tokio/src/signal/unix.rs index 45a091d7..30a05872 100644 --- a/tokio/src/signal/unix.rs +++ b/tokio/src/signal/unix.rs @@ -391,35 +391,7 @@ impl Signal { poll_fn(|cx| self.poll_recv(cx)).await } - /// Polls to receive the next signal notification event, outside of an - /// `async` context. - /// - /// `None` is returned if no more events can be received by this stream. - /// - /// # Examples - /// - /// Polling from a manually implemented future - /// - /// ```rust,no_run - /// use std::pin::Pin; - /// use std::future::Future; - /// use std::task::{Context, Poll}; - /// use tokio::signal::unix::Signal; - /// - /// struct MyFuture { - /// signal: Signal, - /// } - /// - /// impl Future for MyFuture { - /// type Output = Option<()>; - /// - /// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - /// println!("polling MyFuture"); - /// self.signal.poll_recv(cx) - /// } - /// } - /// ``` - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { self.rx.poll_recv(cx) } } diff --git a/tokio/src/stream/mod.rs b/tokio/src/stream/mod.rs index 6a99d9d8..59e1482f 100644 --- a/tokio/src/stream/mod.rs +++ b/tokio/src/stream/mod.rs @@ -270,8 +270,8 @@ pub trait StreamExt: Stream { /// # #[tokio::main(basic_scheduler)] /// async fn main() { /// # time::pause(); - /// let (mut tx1, rx1) = mpsc::channel(10); - /// let (mut tx2, rx2) = mpsc::channel(10); + /// let (tx1, rx1) = mpsc::channel(10); + /// let (tx2, rx2) = mpsc::channel(10); /// /// let mut rx = rx1.merge(rx2); /// diff --git a/tokio/src/stream/stream_map.rs b/tokio/src/stream/stream_map.rs index 2f60ea4d..a1c80f15 100644 --- a/tokio/src/stream/stream_map.rs +++ b/tokio/src/stream/stream_map.rs @@ -57,8 +57,8 @@ use std::task::{Context, Poll}; /// /// #[tokio::main] /// async fn main() { -/// let (mut tx1, rx1) = mpsc::channel(10); -/// let (mut tx2, rx2) = mpsc::channel(10); +/// let (tx1, rx1) = mpsc::channel(10); +/// let (tx2, rx2) = mpsc::channel(10); /// /// tokio::spawn(async move { /// tx1.send(1).await.unwrap(); diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs index a1048ca3..9f324f9c 100644 --- a/tokio/src/sync/batch_semaphore.rs +++ b/tokio/src/sync/batch_semaphore.rs @@ -165,7 +165,6 @@ impl Semaphore { /// permits and notifies all pending waiters. // This will be used once the bounded MPSC is updated to use the new // semaphore implementation. - #[allow(dead_code)] pub(crate) fn close(&self) { let mut waiters = self.waiters.lock().unwrap(); // If the semaphore's permits counter has enough permits for an @@ -185,6 +184,11 @@ impl Semaphore { } } + /// Returns true if the semaphore is closed + pub(crate) fn is_closed(&self) -> bool { + self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED + } + pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> { assert!( num_permits as usize <= Self::MAX_PERMITS, @@ -194,8 +198,8 @@ impl Semaphore { let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT; let mut curr = self.permits.load(Acquire); loop { - // Has the semaphore closed?git - if curr & Self::CLOSED > 0 { + // Has the semaphore closed? + if curr & Self::CLOSED == Self::CLOSED { return Err(TryAcquireError::Closed); } diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index 4c069467..6531931b 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -106,7 +106,7 @@ //! //! #[tokio::main] //! async fn main() { -//! let (mut tx, mut rx) = mpsc::channel(100); +//! let (tx, mut rx) = mpsc::channel(100); //! //! tokio::spawn(async move { //! for i in 0..10 { @@ -150,7 +150,7 @@ //! for _ in 0..10 { //! // Each task needs its own `tx` handle. This is done by cloning the //! // original handle. -//! let mut tx = tx.clone(); +//! let tx = tx.clone(); //! //! tokio::spawn(async move { //! tx.send(&b"data to write"[..]).await.unwrap(); @@ -213,7 +213,7 @@ //! //! // Spawn tasks that will send the increment command. //! for _ in 0..10 { -//! let mut cmd_tx = cmd_tx.clone(); +//! let cmd_tx = cmd_tx.clone(); //! //! join_handles.push(tokio::spawn(async move { //! let (resp_tx, resp_rx) = oneshot::channel(); @@ -443,7 +443,6 @@ cfg_sync! { pub mod oneshot; pub(crate) mod batch_semaphore; - pub(crate) mod semaphore_ll; mod semaphore; pub use semaphore::{Semaphore, SemaphorePermit, OwnedSemaphorePermit}; @@ -473,7 +472,7 @@ cfg_not_sync! { cfg_signal_internal! { pub(crate) mod mpsc; - pub(crate) mod semaphore_ll; + pub(crate) mod batch_semaphore; } } diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index 14e4731a..2d2006d5 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -1,6 +1,6 @@ +use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError}; use crate::sync::mpsc::chan; -use crate::sync::mpsc::error::{ClosedError, SendError, TryRecvError, TrySendError}; -use crate::sync::semaphore_ll as semaphore; +use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError}; cfg_time! { use crate::sync::mpsc::error::SendTimeoutError; @@ -8,6 +8,7 @@ cfg_time! { } use std::fmt; +#[cfg(any(feature = "signal", feature = "process", feature = "stream"))] use std::task::{Context, Poll}; /// Send values to the associated `Receiver`. @@ -17,20 +18,14 @@ pub struct Sender { chan: chan::Tx, } -impl Clone for Sender { - fn clone(&self) -> Self { - Sender { - chan: self.chan.clone(), - } - } -} - -impl fmt::Debug for Sender { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Sender") - .field("chan", &self.chan) - .finish() - } +/// Permit to send one value into the channel. +/// +/// `Permit` values are returned by [`Sender::reserve()`] and are used to +/// guarantee channel capacity before generating a message to send. +/// +/// [`Sender::reserve()`]: Sender::reserve +pub struct Permit<'a, T> { + chan: &'a chan::Tx, } /// Receive values from the associated `Sender`. @@ -41,14 +36,6 @@ pub struct Receiver { chan: chan::Rx, } -impl fmt::Debug for Receiver { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Receiver") - .field("chan", &self.chan) - .finish() - } -} - /// Creates a bounded mpsc channel for communicating between asynchronous tasks /// with backpressure. /// @@ -77,7 +64,7 @@ impl fmt::Debug for Receiver { /// /// #[tokio::main] /// async fn main() { -/// let (mut tx, mut rx) = mpsc::channel(100); +/// let (tx, mut rx) = mpsc::channel(100); /// /// tokio::spawn(async move { /// for i in 0..10 { @@ -125,7 +112,7 @@ impl Receiver { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(100); + /// let (tx, mut rx) = mpsc::channel(100); /// /// tokio::spawn(async move { /// tx.send("hello").await.unwrap(); @@ -143,7 +130,7 @@ impl Receiver { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(100); + /// let (tx, mut rx) = mpsc::channel(100); /// /// tx.send("hello").await.unwrap(); /// tx.send("world").await.unwrap(); @@ -154,12 +141,11 @@ impl Receiver { /// ``` pub async fn recv(&mut self) -> Option { use crate::future::poll_fn; - - poll_fn(|cx| self.poll_recv(cx)).await + poll_fn(|cx| self.chan.recv(cx)).await } - #[doc(hidden)] // TODO: document - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + #[cfg(any(feature = "signal", feature = "process"))] + pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { self.chan.recv(cx) } @@ -178,7 +164,7 @@ impl Receiver { /// use tokio::sync::mpsc; /// /// fn main() { - /// let (mut tx, mut rx) = mpsc::channel::(10); + /// let (tx, mut rx) = mpsc::channel::(10); /// /// let sync_code = thread::spawn(move || { /// assert_eq!(Some(10), rx.blocking_recv()); @@ -215,12 +201,53 @@ impl Receiver { /// Closes the receiving half of a channel, without dropping it. /// /// This prevents any further messages from being sent on the channel while - /// still enabling the receiver to drain messages that are buffered. + /// still enabling the receiver to drain messages that are buffered. Any + /// outstanding [`Permit`] values will still be able to send messages. + /// + /// In order to guarantee no messages are dropped, after calling `close()`, + /// `recv()` must be called until `None` is returned. + /// + /// [`Permit`]: Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(20); + /// + /// tokio::spawn(async move { + /// let mut i = 0; + /// while let Ok(permit) = tx.reserve().await { + /// permit.send(i); + /// i += 1; + /// } + /// }); + /// + /// rx.close(); + /// + /// while let Some(msg) = rx.recv().await { + /// println!("got {}", msg); + /// } + /// + /// // Channel closed and no messages are lost. + /// } + /// ``` pub fn close(&mut self) { self.chan.close(); } } +impl fmt::Debug for Receiver { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Receiver") + .field("chan", &self.chan) + .finish() + } +} + impl Unpin for Receiver {} cfg_stream! { @@ -228,7 +255,7 @@ cfg_stream! { type Item = T; fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_recv(cx) + self.chan.recv(cx) } } } @@ -267,7 +294,7 @@ impl Sender { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(1); + /// let (tx, mut rx) = mpsc::channel(1); /// /// tokio::spawn(async move { /// for i in 0..10 { @@ -283,17 +310,13 @@ impl Sender { /// } /// } /// ``` - pub async fn send(&mut self, value: T) -> Result<(), SendError> { - use crate::future::poll_fn; - - if poll_fn(|cx| self.poll_ready(cx)).await.is_err() { - return Err(SendError(value)); - } - - match self.try_send(value) { - Ok(()) => Ok(()), - Err(TrySendError::Full(_)) => unreachable!(), - Err(TrySendError::Closed(value)) => Err(SendError(value)), + pub async fn send(&self, value: T) -> Result<(), SendError> { + match self.reserve().await { + Ok(permit) => { + permit.send(value); + Ok(()) + } + Err(_) => Err(SendError(value)), } } @@ -304,9 +327,6 @@ impl Sender { /// with [`send`], this function has two failure cases instead of one (one for /// disconnection, one for a full buffer). /// - /// This function may be paired with [`poll_ready`] in order to wait for - /// channel capacity before trying to send a value. - /// /// # Errors /// /// If the channel capacity has been reached, i.e., the channel has `n` @@ -318,7 +338,6 @@ impl Sender { /// an error. The error includes the value passed to `send`. /// /// [`send`]: Sender::send - /// [`poll_ready`]: Sender::poll_ready /// [`channel`]: channel /// [`close`]: Receiver::close /// @@ -330,8 +349,8 @@ impl Sender { /// #[tokio::main] /// async fn main() { /// // Create a channel with buffer size 1 - /// let (mut tx1, mut rx) = mpsc::channel(1); - /// let mut tx2 = tx1.clone(); + /// let (tx1, mut rx) = mpsc::channel(1); + /// let tx2 = tx1.clone(); /// /// tokio::spawn(async move { /// tx1.send(1).await.unwrap(); @@ -359,8 +378,15 @@ impl Sender { /// } /// } /// ``` - pub fn try_send(&mut self, message: T) -> Result<(), TrySendError> { - self.chan.try_send(message)?; + pub fn try_send(&self, message: T) -> Result<(), TrySendError> { + match self.chan.semaphore().0.try_acquire(1) { + Ok(_) => {} + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(message)), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(message)), + } + + // Send the message + self.chan.send(message); Ok(()) } @@ -392,7 +418,7 @@ impl Sender { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(1); + /// let (tx, mut rx) = mpsc::channel(1); /// /// tokio::spawn(async move { /// for i in 0..10 { @@ -412,27 +438,22 @@ impl Sender { #[cfg(feature = "time")] #[cfg_attr(docsrs, doc(cfg(feature = "time")))] pub async fn send_timeout( - &mut self, + &self, value: T, timeout: Duration, ) -> Result<(), SendTimeoutError> { - use crate::future::poll_fn; - - match crate::time::timeout(timeout, poll_fn(|cx| self.poll_ready(cx))).await { + let permit = match crate::time::timeout(timeout, self.reserve()).await { Err(_) => { return Err(SendTimeoutError::Timeout(value)); } Ok(Err(_)) => { return Err(SendTimeoutError::Closed(value)); } - Ok(_) => {} - } + Ok(Ok(permit)) => permit, + }; - match self.try_send(value) { - Ok(()) => Ok(()), - Err(TrySendError::Full(_)) => unreachable!(), - Err(TrySendError::Closed(value)) => Err(SendTimeoutError::Closed(value)), - } + permit.send(value); + Ok(()) } /// Blocking send to call outside of asynchronous contexts. @@ -450,7 +471,7 @@ impl Sender { /// use tokio::sync::mpsc; /// /// fn main() { - /// let (mut tx, mut rx) = mpsc::channel::(1); + /// let (tx, mut rx) = mpsc::channel::(1); /// /// let sync_code = thread::spawn(move || { /// tx.blocking_send(10).unwrap(); @@ -462,92 +483,139 @@ impl Sender { /// sync_code.join().unwrap() /// } /// ``` - pub fn blocking_send(&mut self, value: T) -> Result<(), SendError> { + pub fn blocking_send(&self, value: T) -> Result<(), SendError> { let mut enter_handle = crate::runtime::enter::enter(false); enter_handle.block_on(self.send(value)).unwrap() } - /// Returns `Poll::Ready(Ok(()))` when the channel is able to accept another item. + /// Wait for channel capacity. Once capacity to send one message is + /// available, it is reserved for the caller. /// - /// If the channel is full, then `Poll::Pending` is returned and the task is notified when a - /// slot becomes available. + /// If the channel is full, the function waits for the number of unreceived + /// messages to become less than the channel capacity. Capacity to send one + /// message is reserved for the caller. A [`Permit`] is returned to track + /// the reserved capacity. The [`send`] function on [`Permit`] consumes the + /// reserved capacity. /// - /// Once `poll_ready` returns `Poll::Ready(Ok(()))`, a call to `try_send` will succeed unless - /// the channel has since been closed. To provide this guarantee, the channel reserves one slot - /// in the channel for the coming send. This reserved slot is not available to other `Sender` - /// instances, so you need to be careful to not end up with deadlocks by blocking after calling - /// `poll_ready` but before sending an element. + /// Dropping [`Permit`] without sending a message releases the capacity back + /// to the channel. /// - /// If, after `poll_ready` succeeds, you decide you do not wish to send an item after all, you - /// can use [`disarm`](Sender::disarm) to release the reserved slot. + /// [`Permit`]: Permit + /// [`send`]: Permit::send /// - /// Until an item is sent or [`disarm`](Sender::disarm) is called, repeated calls to - /// `poll_ready` will return either `Poll::Ready(Ok(()))` or `Poll::Ready(Err(_))` if channel - /// is closed. - pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.chan.poll_ready(cx).map_err(|_| ClosedError::new()) - } - - /// Undo a successful call to `poll_ready`. + /// # Examples /// - /// Once a call to `poll_ready` returns `Poll::Ready(Ok(()))`, it holds up one slot in the - /// channel to make room for the coming send. `disarm` allows you to give up that slot if you - /// decide you do not wish to send an item after all. After calling `disarm`, you must call - /// `poll_ready` until it returns `Poll::Ready(Ok(()))` before attempting to send again. + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); /// - /// Returns `false` if no slot is reserved for this sender (usually because `poll_ready` was - /// not previously called, or did not succeed). + /// // Reserve capacity + /// let permit = tx.reserve().await.unwrap(); /// - /// # Motivation + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); /// - /// Since `poll_ready` takes up one of the finite number of slots in a bounded channel, callers - /// need to send an item shortly after `poll_ready` succeeds. If they do not, idle senders may - /// take up all the slots of the channel, and prevent active senders from getting any requests - /// through. Consider this code that forwards from one channel to another: + /// // Sending on the permit succeeds + /// permit.send(456); /// - /// ```rust,ignore - /// loop { - /// ready!(tx.poll_ready(cx))?; - /// if let Some(item) = ready!(rx.poll_recv(cx)) { - /// tx.try_send(item)?; - /// } else { - /// break; - /// } + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); /// } /// ``` + pub async fn reserve(&self) -> Result, SendError<()>> { + match self.chan.semaphore().0.acquire(1).await { + Ok(_) => {} + Err(_) => return Err(SendError(())), + } + + Ok(Permit { chan: &self.chan }) + } +} + +impl Clone for Sender { + fn clone(&self) -> Self { + Sender { + chan: self.chan.clone(), + } + } +} + +impl fmt::Debug for Sender { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Sender") + .field("chan", &self.chan) + .finish() + } +} + +// ===== impl Permit ===== + +impl Permit<'_, T> { + /// Sends a value using the reserved capacity. + /// + /// Capacity for the message has already been reserved. The message is sent + /// to the receiver and the permit is consumed. The operation will succeed + /// even if the receiver half has been closed. See [`Receiver::close`] for + /// more details on performing a clean shutdown. + /// + /// [`Receiver::close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve().await.unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); /// - /// If many such forwarders exist, and they all forward into a single (cloned) `Sender`, then - /// any number of forwarders may be waiting for `rx.poll_recv` at the same time. While they do, - /// they are effectively each reducing the channel's capacity by 1. If enough of these - /// forwarders are idle, forwarders whose `rx` _do_ have elements will be unable to find a spot - /// for them through `poll_ready`, and the system will deadlock. - /// - /// `disarm` solves this problem by allowing you to give up the reserved slot if you find that - /// you have to block. We can then fix the code above by writing: - /// - /// ```rust,ignore - /// loop { - /// ready!(tx.poll_ready(cx))?; - /// let item = rx.poll_recv(cx); - /// if let Poll::Ready(Ok(_)) = item { - /// // we're going to send the item below, so don't disarm - /// } else { - /// // give up our send slot, we won't need it for a while - /// tx.disarm(); - /// } - /// if let Some(item) = ready!(item) { - /// tx.try_send(item)?; - /// } else { - /// break; - /// } + /// // Send a message on the permit + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); /// } /// ``` - pub fn disarm(&mut self) -> bool { - if self.chan.is_ready() { - self.chan.disarm(); - true - } else { - false + pub fn send(self, value: T) { + use std::mem; + + self.chan.send(value); + + // Avoid the drop logic + mem::forget(self); + } +} + +impl Drop for Permit<'_, T> { + fn drop(&mut self) { + use chan::Semaphore; + + let semaphore = self.chan.semaphore(); + + // Add the permit back to the semaphore + semaphore.add_permit(); + + if semaphore.is_closed() && semaphore.is_idle() { + self.chan.wake_rx(); } } } + +impl fmt::Debug for Permit<'_, T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Permit") + .field("chan", &self.chan) + .finish() + } +} diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index 0a53cda2..2d3f0149 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -2,8 +2,8 @@ use crate::loom::cell::UnsafeCell; use crate::loom::future::AtomicWaker; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Arc; -use crate::sync::mpsc::error::{ClosedError, TryRecvError}; -use crate::sync::mpsc::{error, list}; +use crate::sync::mpsc::error::TryRecvError; +use crate::sync::mpsc::list; use std::fmt; use std::process; @@ -12,21 +12,13 @@ use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; /// Channel sender -pub(crate) struct Tx { +pub(crate) struct Tx { inner: Arc>, - permit: S::Permit, } -impl fmt::Debug for Tx -where - S::Permit: fmt::Debug, - S: fmt::Debug, -{ +impl fmt::Debug for Tx { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Tx") - .field("inner", &self.inner) - .field("permit", &self.permit) - .finish() + fmt.debug_struct("Tx").field("inner", &self.inner).finish() } } @@ -35,71 +27,20 @@ pub(crate) struct Rx { inner: Arc>, } -impl fmt::Debug for Rx -where - S: fmt::Debug, -{ +impl fmt::Debug for Rx { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Rx").field("inner", &self.inner).finish() } } -#[derive(Debug, Eq, PartialEq)] -pub(crate) enum TrySendError { - Closed, - Full, -} - -impl From<(T, TrySendError)> for error::SendError { - fn from(src: (T, TrySendError)) -> error::SendError { - match src.1 { - TrySendError::Closed => error::SendError(src.0), - TrySendError::Full => unreachable!(), - } - } -} - -impl From<(T, TrySendError)> for error::TrySendError { - fn from(src: (T, TrySendError)) -> error::TrySendError { - match src.1 { - TrySendError::Closed => error::TrySendError::Closed(src.0), - TrySendError::Full => error::TrySendError::Full(src.0), - } - } -} - pub(crate) trait Semaphore { - type Permit; - - fn new_permit() -> Self::Permit; - - /// The permit is dropped without a value being sent. In this case, the - /// permit must be returned to the semaphore. - /// - /// # Return - /// - /// Returns true if the permit was acquired. - fn drop_permit(&self, permit: &mut Self::Permit) -> bool; - fn is_idle(&self) -> bool; fn add_permit(&self); - fn poll_acquire( - &self, - cx: &mut Context<'_>, - permit: &mut Self::Permit, - ) -> Poll>; - - fn try_acquire(&self, permit: &mut Self::Permit) -> Result<(), TrySendError>; - - /// A value was sent into the channel and the permit held by `tx` is - /// dropped. In this case, the permit should not immeditely be returned to - /// the semaphore. Instead, the permit is returnred to the semaphore once - /// the sent value is read by the rx handle. - fn forget(&self, permit: &mut Self::Permit); - fn close(&self); + + fn is_closed(&self) -> bool; } struct Chan { @@ -157,10 +98,7 @@ impl fmt::Debug for RxFields { unsafe impl Send for Chan {} unsafe impl Sync for Chan {} -pub(crate) fn channel(semaphore: S) -> (Tx, Rx) -where - S: Semaphore, -{ +pub(crate) fn channel(semaphore: S) -> (Tx, Rx) { let (tx, rx) = list::channel(); let chan = Arc::new(Chan { @@ -179,48 +117,27 @@ where // ===== impl Tx ===== -impl Tx -where - S: Semaphore, -{ +impl Tx { fn new(chan: Arc>) -> Tx { - Tx { - inner: chan, - permit: S::new_permit(), - } + Tx { inner: chan } } - pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.semaphore.poll_acquire(cx, &mut self.permit) - } - - pub(crate) fn disarm(&mut self) { - // TODO: should this error if not acquired? - self.inner.semaphore.drop_permit(&mut self.permit); + pub(super) fn semaphore(&self) -> &S { + &self.inner.semaphore } /// Send a message and notify the receiver. - pub(crate) fn try_send(&mut self, value: T) -> Result<(), (T, TrySendError)> { - self.inner.try_send(value, &mut self.permit) - } -} - -impl Tx { - pub(crate) fn is_ready(&self) -> bool { - self.permit.is_acquired() + pub(crate) fn send(&self, value: T) { + self.inner.send(value); } -} -impl Tx { - pub(crate) fn send_unbounded(&self, value: T) -> Result<(), (T, TrySendError)> { - self.inner.try_send(value, &mut ()) + /// Wake the receive half + pub(crate) fn wake_rx(&self) { + self.inner.rx_waker.wake(); } } -impl Clone for Tx -where - S: Semaphore, -{ +impl Clone for Tx { fn clone(&self) -> Tx { // Using a Relaxed ordering here is sufficient as the caller holds a // strong ref to `self`, preventing a concurrent decrement to zero. @@ -228,22 +145,12 @@ where Tx { inner: self.inner.clone(), - permit: S::new_permit(), } } } -impl Drop for Tx -where - S: Semaphore, -{ +impl Drop for Tx { fn drop(&mut self) { - let notify = self.inner.semaphore.drop_permit(&mut self.permit); - - if notify && self.inner.semaphore.is_idle() { - self.inner.rx_waker.wake(); - } - if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 { return; } @@ -252,16 +159,13 @@ where self.inner.tx.close(); // Notify the receiver - self.inner.rx_waker.wake(); + self.wake_rx(); } } // ===== impl Rx ===== -impl Rx -where - S: Semaphore, -{ +impl Rx { fn new(chan: Arc>) -> Rx { Rx { inner: chan } } @@ -349,10 +253,7 @@ where } } -impl Drop for Rx -where - S: Semaphore, -{ +impl Drop for Rx { fn drop(&mut self) { use super::block::Read::Value; @@ -370,25 +271,13 @@ where // ===== impl Chan ===== -impl Chan -where - S: Semaphore, -{ - fn try_send(&self, value: T, permit: &mut S::Permit) -> Result<(), (T, TrySendError)> { - if let Err(e) = self.semaphore.try_acquire(permit) { - return Err((value, e)); - } - +impl Chan { + fn send(&self, value: T) { // Push the value self.tx.push(value); // Notify the rx task self.rx_waker.wake(); - - // Release the permit - self.semaphore.forget(permit); - - Ok(()) } } @@ -407,74 +296,24 @@ impl Drop for Chan { } } -use crate::sync::semaphore_ll::TryAcquireError; - -impl From for TrySendError { - fn from(src: TryAcquireError) -> TrySendError { - if src.is_closed() { - TrySendError::Closed - } else if src.is_no_permits() { - TrySendError::Full - } else { - unreachable!(); - } - } -} - // ===== impl Semaphore for (::Semaphore, capacity) ===== -use crate::sync::semaphore_ll::Permit; - -impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) { - type Permit = Permit; - - fn new_permit() -> Permit { - Permit::new() - } - - fn drop_permit(&self, permit: &mut Permit) -> bool { - let ret = permit.is_acquired(); - permit.release(1, &self.0); - ret - } - +impl Semaphore for (crate::sync::batch_semaphore::Semaphore, usize) { fn add_permit(&self) { - self.0.add_permits(1) + self.0.release(1) } fn is_idle(&self) -> bool { self.0.available_permits() == self.1 } - fn poll_acquire( - &self, - cx: &mut Context<'_>, - permit: &mut Permit, - ) -> Poll> { - // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); - - permit - .poll_acquire(cx, 1, &self.0) - .map_err(|_| ClosedError::new()) - .map(move |r| { - coop.made_progress(); - r - }) - } - - fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> { - permit.try_acquire(1, &self.0)?; - Ok(()) - } - - fn forget(&self, permit: &mut Self::Permit) { - permit.forget(1); - } - fn close(&self) { self.0.close(); } + + fn is_closed(&self) -> bool { + self.0.is_closed() + } } // ===== impl Semaphore for AtomicUsize ===== @@ -483,14 +322,6 @@ use std::sync::atomic::Ordering::{Acquire, Release}; use std::usize; impl Semaphore for AtomicUsize { - type Permit = (); - - fn new_permit() {} - - fn drop_permit(&self, _permit: &mut ()) -> bool { - false - } - fn add_permit(&self) { let prev = self.fetch_sub(2, Release); @@ -504,40 +335,11 @@ impl Semaphore for AtomicUsize { self.load(Acquire) >> 1 == 0 } - fn poll_acquire( - &self, - _cx: &mut Context<'_>, - permit: &mut (), - ) -> Poll> { - Ready(self.try_acquire(permit).map_err(|_| ClosedError::new())) - } - - fn try_acquire(&self, _permit: &mut ()) -> Result<(), TrySendError> { - let mut curr = self.load(Acquire); - - loop { - if curr & 1 == 1 { - return Err(TrySendError::Closed); - } - - if curr == usize::MAX ^ 1 { - // Overflowed the ref count. There is no safe way to recover, so - // abort the process. In practice, this should never happen. - process::abort() - } - - match self.compare_exchange(curr, curr + 2, AcqRel, Acquire) { - Ok(_) => return Ok(()), - Err(actual) => { - curr = actual; - } - } - } - } - - fn forget(&self, _permit: &mut ()) {} - fn close(&self) { self.fetch_or(1, Release); } + + fn is_closed(&self) -> bool { + self.load(Acquire) & 1 == 1 + } } diff --git a/tokio/src/sync/mpsc/error.rs b/tokio/src/sync/mpsc/error.rs index 72c42aa5..77054529 100644 --- a/tokio/src/sync/mpsc/error.rs +++ b/tokio/src/sync/mpsc/error.rs @@ -94,26 +94,6 @@ impl fmt::Display for TryRecvError { impl Error for TryRecvError {} -// ===== ClosedError ===== - -/// Error returned by [`Sender::poll_ready`](super::Sender::poll_ready). -#[derive(Debug)] -pub struct ClosedError(()); - -impl ClosedError { - pub(crate) fn new() -> ClosedError { - ClosedError(()) - } -} - -impl fmt::Display for ClosedError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "channel closed") - } -} - -impl Error for ClosedError {} - cfg_time! { // ===== SendTimeoutError ===== diff --git a/tokio/src/sync/mpsc/mod.rs b/tokio/src/sync/mpsc/mod.rs index 7e663da8..a2bcf83b 100644 --- a/tokio/src/sync/mpsc/mod.rs +++ b/tokio/src/sync/mpsc/mod.rs @@ -76,7 +76,7 @@ pub(super) mod block; mod bounded; -pub use self::bounded::{channel, Receiver, Sender}; +pub use self::bounded::{channel, Permit, Receiver, Sender}; mod chan; diff --git a/tokio/src/sync/mpsc/unbounded.rs b/tokio/src/sync/mpsc/unbounded.rs index 6b2ca722..59456375 100644 --- a/tokio/src/sync/mpsc/unbounded.rs +++ b/tokio/src/sync/mpsc/unbounded.rs @@ -73,8 +73,7 @@ impl UnboundedReceiver { UnboundedReceiver { chan } } - #[doc(hidden)] // TODO: doc - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { self.chan.recv(cx) } @@ -174,7 +173,41 @@ impl UnboundedSender { /// [`close`]: UnboundedReceiver::close /// [`UnboundedReceiver`]: UnboundedReceiver pub fn send(&self, message: T) -> Result<(), SendError> { - self.chan.send_unbounded(message)?; + if !self.inc_num_messages() { + return Err(SendError(message)); + } + + self.chan.send(message); Ok(()) } + + fn inc_num_messages(&self) -> bool { + use std::process; + use std::sync::atomic::Ordering::{AcqRel, Acquire}; + + let mut curr = self.chan.semaphore().load(Acquire); + + loop { + if curr & 1 == 1 { + return false; + } + + if curr == usize::MAX ^ 1 { + // Overflowed the ref count. There is no safe way to recover, so + // abort the process. In practice, this should never happen. + process::abort() + } + + match self + .chan + .semaphore() + .compare_exchange(curr, curr + 2, AcqRel, Acquire) + { + Ok(_) => return true, + Err(actual) => { + curr = actual; + } + } + } + } } diff --git a/tokio/src/sync/semaphore_ll.rs b/tokio/src/sync/semaphore_ll.rs deleted file mode 100644 index f044095f..00000000 --- a/tokio/src/sync/semaphore_ll.rs +++ /dev/null @@ -1,1221 +0,0 @@ -#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] - -//! Thread-safe, asynchronous counting semaphore. -//! -//! A `Semaphore` instance holds a set of permits. Permits are used to -//! synchronize access to a shared resource. -//! -//! Before accessing the shared resource, callers acquire a permit from the -//! semaphore. Once the permit is acquired, the caller then enters the critical -//! section. If no permits are available, then acquiring the semaphore returns -//! `Pending`. The task is woken once a permit becomes available. - -use crate::loom::cell::UnsafeCell; -use crate::loom::future::AtomicWaker; -use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; -use crate::loom::thread; - -use std::cmp; -use std::fmt; -use std::ptr::{self, NonNull}; -use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release}; -use std::task::Poll::{Pending, Ready}; -use std::task::{Context, Poll}; -use std::usize; - -/// Futures-aware semaphore. -pub(crate) struct Semaphore { - /// Tracks both the waiter queue tail pointer and the number of remaining - /// permits. - state: AtomicUsize, - - /// waiter queue head pointer. - head: UnsafeCell>, - - /// Coordinates access to the queue head. - rx_lock: AtomicUsize, - - /// Stub waiter node used as part of the MPSC channel algorithm. - stub: Box, -} - -/// A semaphore permit -/// -/// Tracks the lifecycle of a semaphore permit. -/// -/// An instance of `Permit` is intended to be used with a **single** instance of -/// `Semaphore`. Using a single instance of `Permit` with multiple semaphore -/// instances will result in unexpected behavior. -/// -/// `Permit` does **not** release the permit back to the semaphore on drop. It -/// is the user's responsibility to ensure that `Permit::release` is called -/// before dropping the permit. -#[derive(Debug)] -pub(crate) struct Permit { - waiter: Option>, - state: PermitState, -} - -/// Error returned by `Permit::poll_acquire`. -#[derive(Debug)] -pub(crate) struct AcquireError(()); - -/// Error returned by `Permit::try_acquire`. -#[derive(Debug)] -pub(crate) enum TryAcquireError { - Closed, - NoPermits, -} - -/// Node used to notify the semaphore waiter when permit is available. -#[derive(Debug)] -struct Waiter { - /// Stores waiter state. - /// - /// See `WaiterState` for more details. - state: AtomicUsize, - - /// Task to wake when a permit is made available. - waker: AtomicWaker, - - /// Next pointer in the queue of waiting senders. - next: AtomicPtr, -} - -/// Semaphore state -/// -/// The 2 low bits track the modes. -/// -/// - Closed -/// - Full -/// -/// When not full, the rest of the `usize` tracks the total number of messages -/// in the channel. When full, the rest of the `usize` is a pointer to the tail -/// of the "waiting senders" queue. -#[derive(Copy, Clone)] -struct SemState(usize); - -/// Permit state -#[derive(Debug, Copy, Clone)] -enum PermitState { - /// Currently waiting for permits to be made available and assigned to the - /// waiter. - Waiting(u16), - - /// The number of acquired permits - Acquired(u16), -} - -/// State for an individual waker node -#[derive(Debug, Copy, Clone)] -struct WaiterState(usize); - -/// Waiter node is in the semaphore queue -const QUEUED: usize = 0b001; - -/// Semaphore has been closed, no more permits will be issued. -const CLOSED: usize = 0b10; - -/// The permit that owns the `Waiter` dropped. -const DROPPED: usize = 0b100; - -/// Represents "one requested permit" in the waiter state -const PERMIT_ONE: usize = 0b1000; - -/// Masks the waiter state to only contain bits tracking number of requested -/// permits. -const PERMIT_MASK: usize = usize::MAX - (PERMIT_ONE - 1); - -/// How much to shift a permit count to pack it into the waker state -const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros(); - -/// Flag differentiating between available permits and waiter pointers. -/// -/// If we assume pointers are properly aligned, then the least significant bit -/// will always be zero. So, we use that bit to track if the value represents a -/// number. -const NUM_FLAG: usize = 0b01; - -/// Signal the semaphore is closed -const CLOSED_FLAG: usize = 0b10; - -/// Maximum number of permits a semaphore can manage -const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT; - -/// When representing "numbers", the state has to be shifted this much (to get -/// rid of the flag bit). -const NUM_SHIFT: usize = 2; - -// ===== impl Semaphore ===== - -impl Semaphore { - /// Creates a new semaphore with the initial number of permits - /// - /// # Panics - /// - /// Panics if `permits` is zero. - pub(crate) fn new(permits: usize) -> Semaphore { - let stub = Box::new(Waiter::new()); - let ptr = NonNull::from(&*stub); - - // Allocations are aligned - debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0); - - let state = SemState::new(permits, &stub); - - Semaphore { - state: AtomicUsize::new(state.to_usize()), - head: UnsafeCell::new(ptr), - rx_lock: AtomicUsize::new(0), - stub, - } - } - - /// Returns the current number of available permits - pub(crate) fn available_permits(&self) -> usize { - let curr = SemState(self.state.load(Acquire)); - curr.available_permits() - } - - /// Tries to acquire the requested number of permits, registering the waiter - /// if not enough permits are available. - fn poll_acquire( - &self, - cx: &mut Context<'_>, - num_permits: u16, - permit: &mut Permit, - ) -> Poll> { - self.poll_acquire2(num_permits, || { - let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new())); - - waiter.waker.register_by_ref(cx.waker()); - - Some(NonNull::from(&**waiter)) - }) - } - - fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> { - match self.poll_acquire2(num_permits, || None) { - Poll::Ready(res) => res.map_err(to_try_acquire), - Poll::Pending => Err(TryAcquireError::NoPermits), - } - } - - /// Polls for a permit - /// - /// Tries to acquire available permits first. If unable to acquire a - /// sufficient number of permits, the caller's waiter is pushed onto the - /// semaphore's wait queue. - fn poll_acquire2( - &self, - num_permits: u16, - mut get_waiter: F, - ) -> Poll> - where - F: FnMut() -> Option>, - { - let num_permits = num_permits as usize; - - // Load the current state - let mut curr = SemState(self.state.load(Acquire)); - - // Saves a ref to the waiter node - let mut maybe_waiter: Option> = None; - - /// Used in branches where we attempt to push the waiter into the wait - /// queue but fail due to permits becoming available or the wait queue - /// transitioning to "closed". In this case, the waiter must be - /// transitioned back to the "idle" state. - macro_rules! revert_to_idle { - () => { - if let Some(waiter) = maybe_waiter { - unsafe { waiter.as_ref() }.revert_to_idle(); - } - }; - } - - loop { - let mut next = curr; - - if curr.is_closed() { - revert_to_idle!(); - return Ready(Err(AcquireError::closed())); - } - - let acquired = next.acquire_permits(num_permits, &self.stub); - - if !acquired { - // There are not enough available permits to satisfy the - // request. The permit transitions to a waiting state. - debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits); - - if let Some(waiter) = maybe_waiter.as_ref() { - // Safety: the caller owns the waiter. - let w = unsafe { waiter.as_ref() }; - w.set_permits_to_acquire(num_permits - curr.available_permits()); - } else { - // Get the waiter for the permit. - if let Some(waiter) = get_waiter() { - // Safety: the caller owns the waiter. - let w = unsafe { waiter.as_ref() }; - - // If there are any currently available permits, the - // waiter acquires those immediately and waits for the - // remaining permits to become available. - if !w.to_queued(num_permits - curr.available_permits()) { - // The node is alrady queued, there is no further work - // to do. - return Pending; - } - - maybe_waiter = Some(waiter); - } else { - // No waiter, this indicates the caller does not wish to - // "wait", so there is nothing left to do. - return Pending; - } - } - - next.set_waiter(maybe_waiter.unwrap()); - } - - debug_assert_ne!(curr.0, 0); - debug_assert_ne!(next.0, 0); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => { - if acquired { - // Successfully acquire permits **without** queuing the - // waiter node. The waiter node is not currently in the - // queue. - revert_to_idle!(); - return Ready(Ok(())); - } else { - // The node is pushed into the queue, the final step is - // to set the node's "next" pointer to return the wait - // queue into a consistent state. - - let prev_waiter = - curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub)); - - let waiter = maybe_waiter.unwrap(); - - // Link the nodes. - // - // Safety: the mpsc algorithm guarantees the old tail of - // the queue is not removed from the queue during the - // push process. - unsafe { - prev_waiter.as_ref().store_next(waiter); - } - - return Pending; - } - } - Err(actual) => { - curr = SemState(actual); - } - } - } - } - - /// Closes the semaphore. This prevents the semaphore from issuing new - /// permits and notifies all pending waiters. - pub(crate) fn close(&self) { - // Acquire the `rx_lock`, setting the "closed" flag on the lock. - let prev = self.rx_lock.fetch_or(1, AcqRel); - - if prev != 0 { - // Another thread has the lock and will be responsible for notifying - // pending waiters. - return; - } - - self.add_permits_locked(0, true); - } - /// Adds `n` new permits to the semaphore. - /// - /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. - pub(crate) fn add_permits(&self, n: usize) { - if n == 0 { - return; - } - - // TODO: Handle overflow. A panic is not sufficient, the process must - // abort. - let prev = self.rx_lock.fetch_add(n << 1, AcqRel); - - if prev != 0 { - // Another thread has the lock and will be responsible for notifying - // pending waiters. - return; - } - - self.add_permits_locked(n, false); - } - - fn add_permits_locked(&self, mut rem: usize, mut closed: bool) { - while rem > 0 || closed { - if closed { - SemState::fetch_set_closed(&self.state, AcqRel); - } - - // Release the permits and notify - self.add_permits_locked2(rem, closed); - - let n = rem << 1; - - let actual = if closed { - let actual = self.rx_lock.fetch_sub(n | 1, AcqRel); - closed = false; - actual - } else { - let actual = self.rx_lock.fetch_sub(n, AcqRel); - closed = actual & 1 == 1; - actual - }; - - rem = (actual >> 1) - rem; - } - } - - /// Releases a specific amount of permits to the semaphore - /// - /// This function is called by `add_permits` after the add lock has been - /// acquired. - fn add_permits_locked2(&self, mut n: usize, closed: bool) { - // If closing the semaphore, we want to drain the entire queue. The - // number of permits being assigned doesn't matter. - if closed { - n = usize::MAX; - } - - 'outer: while n > 0 { - unsafe { - let mut head = self.head.with(|head| *head); - let mut next_ptr = head.as_ref().next.load(Acquire); - - let stub = self.stub(); - - if head == stub { - // The stub node indicates an empty queue. Any remaining - // permits get assigned back to the semaphore. - let next = match NonNull::new(next_ptr) { - Some(next) => next, - None => { - // This loop is not part of the standard intrusive mpsc - // channel algorithm. This is where we atomically pop - // the last task and add `n` to the remaining capacity. - // - // This modification to the pop algorithm works because, - // at this point, we have not done any work (only done - // reading). We have a *pretty* good idea that there is - // no concurrent pusher. - // - // The capacity is then atomically added by doing an - // AcqRel CAS on `state`. The `state` cell is the - // linchpin of the algorithm. - // - // By successfully CASing `head` w/ AcqRel, we ensure - // that, if any thread was racing and entered a push, we - // see that and abort pop, retrying as it is - // "inconsistent". - let mut curr = SemState::load(&self.state, Acquire); - - loop { - if curr.has_waiter(&self.stub) { - // A waiter is being added concurrently. - // This is the MPSC queue's "inconsistent" - // state and we must loop and try again. - thread::yield_now(); - continue 'outer; - } - - // If closing, nothing more to do. - if closed { - debug_assert!(curr.is_closed(), "state = {:?}", curr); - return; - } - - let mut next = curr; - next.release_permits(n, &self.stub); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return, - Err(actual) => { - curr = SemState(actual); - } - } - } - } - }; - - self.head.with_mut(|head| *head = next); - head = next; - next_ptr = next.as_ref().next.load(Acquire); - } - - // `head` points to a waiter assign permits to the waiter. If - // all requested permits are satisfied, then we can continue, - // otherwise the node stays in the wait queue. - if !head.as_ref().assign_permits(&mut n, closed) { - assert_eq!(n, 0); - return; - } - - if let Some(next) = NonNull::new(next_ptr) { - self.head.with_mut(|head| *head = next); - - self.remove_queued(head, closed); - continue 'outer; - } - - let state = SemState::load(&self.state, Acquire); - - // This must always be a pointer as the wait list is not empty. - let tail = state.waiter().unwrap(); - - if tail != head { - // Inconsistent - thread::yield_now(); - continue 'outer; - } - - self.push_stub(closed); - - next_ptr = head.as_ref().next.load(Acquire); - - if let Some(next) = NonNull::new(next_ptr) { - self.head.with_mut(|head| *head = next); - - self.remove_queued(head, closed); - continue 'outer; - } - - // Inconsistent state, loop - thread::yield_now(); - } - } - } - - /// The wait node has had all of its permits assigned and has been removed - /// from the wait queue. - /// - /// Attempt to remove the QUEUED bit from the node. If additional permits - /// are concurrently requested, the node must be pushed back into the wait - /// queued. - fn remove_queued(&self, waiter: NonNull, closed: bool) { - let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire)); - - loop { - if curr.is_dropped() { - // The Permit dropped, it is on us to release the memory - let _ = unsafe { Box::from_raw(waiter.as_ptr()) }; - return; - } - - // The node is removed from the queue. We attempt to unset the - // queued bit, but concurrently the waiter has requested more - // permits. When the waiter requested more permits, it saw the - // queued bit set so took no further action. This requires us to - // push the node back into the queue. - if curr.permits_to_acquire() > 0 { - // More permits are requested. The waiter must be re-queued - unsafe { - self.push_waiter(waiter, closed); - } - return; - } - - let mut next = curr; - next.unset_queued(); - - let w = unsafe { waiter.as_ref() }; - - match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return, - Err(actual) => { - curr = WaiterState(actual); - } - } - } - } - - unsafe fn push_stub(&self, closed: bool) { - self.push_waiter(self.stub(), closed); - } - - unsafe fn push_waiter(&self, waiter: NonNull, closed: bool) { - // Set the next pointer. This does not require an atomic operation as - // this node is not accessible. The write will be flushed with the next - // operation - waiter.as_ref().next.store(ptr::null_mut(), Relaxed); - - // Update the tail to point to the new node. We need to see the previous - // node in order to update the next pointer as well as release `task` - // to any other threads calling `push`. - let next = SemState::new_ptr(waiter, closed); - let prev = SemState(self.state.swap(next.0, AcqRel)); - - debug_assert_eq!(closed, prev.is_closed()); - - // This function is only called when there are pending tasks. Because of - // this, the state must *always* be in pointer mode. - let prev = prev.waiter().unwrap(); - - // No cycles plz - debug_assert_ne!(prev, waiter); - - // Release `task` to the consume end. - prev.as_ref().next.store(waiter.as_ptr(), Release); - } - - fn stub(&self) -> NonNull { - unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) } - } -} - -impl Drop for Semaphore { - fn drop(&mut self) { - self.close(); - } -} - -impl fmt::Debug for Semaphore { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Semaphore") - .field("state", &SemState::load(&self.state, Relaxed)) - .field("head", &self.head.with(|ptr| ptr)) - .field("rx_lock", &self.rx_lock.load(Relaxed)) - .field("stub", &self.stub) - .finish() - } -} - -unsafe impl Send for Semaphore {} -unsafe impl Sync for Semaphore {} - -// ===== impl Permit ===== - -impl Permit { - /// Creates a new `Permit`. - /// - /// The permit begins in the "unacquired" state. - pub(crate) fn new() -> Permit { - use PermitState::Acquired; - - Permit { - waiter: None, - state: Acquired(0), - } - } - - /// Returns `true` if the permit has been acquired - #[allow(dead_code)] // may be used later - pub(crate) fn is_acquired(&self) -> bool { - match self.state { - PermitState::Acquired(num) if num > 0 => true, - _ => false, - } - } - - /// Tries to acquire the permit. If no permits are available, the current task - /// is notified once a new permit becomes available. - pub(crate) fn poll_acquire( - &mut self, - cx: &mut Context<'_>, - num_permits: u16, - semaphore: &Semaphore, - ) -> Poll> { - use std::cmp::Ordering::*; - use PermitState::*; - - match self.state { - Waiting(requested) => { - // There must be a waiter - let waiter = self.waiter.as_ref().unwrap(); - - match requested.cmp(&num_permits) { - Less => { - let delta = num_permits - requested; - - // Request additional permits. If the waiter has been - // dequeued, it must be re-queued. - if !waiter.try_inc_permits_to_acquire(delta as usize) { - let waiter = NonNull::from(&**waiter); - - // Ignore the result. The check for - // `permits_to_acquire()` will converge the state as - // needed - let _ = semaphore.poll_acquire2(delta, || Some(waiter))?; - } - - self.state = Waiting(num_permits); - } - Greater => { - let delta = requested - num_permits; - let to_release = waiter.try_dec_permits_to_acquire(delta as usize); - - semaphore.add_permits(to_release); - self.state = Waiting(num_permits); - } - Equal => {} - } - - if waiter.permits_to_acquire()? == 0 { - self.state = Acquired(requested); - return Ready(Ok(())); - } - - waiter.waker.register_by_ref(cx.waker()); - - if waiter.permits_to_acquire()? == 0 { - self.state = Acquired(requested); - return Ready(Ok(())); - } - - Pending - } - Acquired(acquired) => { - if acquired >= num_permits { - Ready(Ok(())) - } else { - match semaphore.poll_acquire(cx, num_permits - acquired, self)? { - Ready(()) => { - self.state = Acquired(num_permits); - Ready(Ok(())) - } - Pending => { - self.state = Waiting(num_permits); - Pending - } - } - } - } - } - } - - /// Tries to acquire the permit. - pub(crate) fn try_acquire( - &mut self, - num_permits: u16, - semaphore: &Semaphore, - ) -> Result<(), TryAcquireError> { - use PermitState::*; - - match self.state { - Waiting(requested) => { - // There must be a wai