path: root/tokio/src/sync
authorCarl Lerche <me@carllerche.com>2019-12-18 10:13:15 -0800
committerGitHub <noreply@github.com>2019-12-18 10:13:15 -0800
commit7c010ed030d0084d38451bbebf727c6695c2dbac (patch)
treed2507c71f9b6c1cdb988097592daa9566ec2ec51 /tokio/src/sync
parent42c942de1433c37946f81d185459b8ee54928d3d (diff)
sync: add broadcast channel (#1943)
Adds a broadcast channel implementation. A broadcast channel is a multi-producer, multi-consumer channel where each consumer receives a clone of every value sent. This is useful for implementing pub/sub style patterns. The channel is implemented as a ring buffer: a `Vec` of the specified capacity is allocated when the channel is created, and values are pushed into its slots. When the channel is full, a send overwrites the oldest value. Receivers detect this and return an error on the next call to receive. This bounds buffering without leaving the channel at the mercy of its slowest consumer. Closes: #1585
Diffstat (limited to 'tokio/src/sync')
-rw-r--r--  tokio/src/sync/broadcast.rs            | 989
-rw-r--r--  tokio/src/sync/mod.rs                  |   2
-rw-r--r--  tokio/src/sync/tests/loom_broadcast.rs | 144
-rw-r--r--  tokio/src/sync/tests/mod.rs            |   1
4 files changed, 1136 insertions, 0 deletions
diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs
new file mode 100644
index 00000000..093aa1f3
--- /dev/null
+++ b/tokio/src/sync/broadcast.rs
@@ -0,0 +1,989 @@
+//! A multi-producer, multi-consumer broadcast queue. Each sent value is seen by
+//! all consumers.
+//!
+//! A [`Sender`] is used to broadcast values to **all** connected [`Receiver`]
+//! handles. [`Sender`] handles are clone-able, allowing concurrent send and
+//! receive actions. [`Sender`] is also `Sync`, so it may be shared across
+//! threads without cloning.
+//!
+//! When a value is sent, **all** [`Receiver`] handles are notified and will
+//! receive the value. The value is stored once inside the channel and cloned on
+//! demand for each receiver. Once all receivers have received a clone of the
+//! value, the value is released from the channel.
+//!
+//! A channel is created by calling [`channel`], specifying the maximum number
+//! of messages the channel can retain at any given time.
+//!
+//! New [`Receiver`] handles are created by calling [`Sender::subscribe`]. The
+//! returned [`Receiver`] will receive values sent **after** the call to
+//! `subscribe`.
+//!
+//! ## Lagging
+//!
+//! As sent messages must be retained until **all** [`Receiver`] handles have
+//! received a clone, broadcast channels are susceptible to the "slow receiver"
+//! problem: all but one receiver keep up with the rate at which values are
+//! sent, and because the remaining receiver is stalled, the channel starts to
+//! fill up.
+//!
+//! This broadcast channel implementation handles this case by setting a hard
+//! upper bound on the number of values the channel may retain at any given
+//! time. This upper bound is passed to the [`channel`] function as an argument.
+//!
+//! If a value is sent when the channel is at capacity, the oldest value
+//! currently held by the channel is released. This frees up space for the new
+//! value. Any receiver that has not yet seen the released value will return
+//! [`RecvError::Lagged`] the next time [`recv`] is called.
+//!
+//! Once [`RecvError::Lagged`] is returned, the lagging receiver's position is
+//! updated to the oldest value contained by the channel. The next call to
+//! [`recv`] will return this value.
+//!
+//! This behavior enables a receiver to detect when it has lagged so far behind
+//! that data has been dropped. The caller may decide how to respond to this:
+//! either by aborting its task or by tolerating lost messages and resuming
+//! consumption of the channel.
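+//!
+//! As a minimal sketch, a consumer that tolerates lost messages might recover
+//! from lag like this:
+//!
+//! ```
+//! use tokio::sync::broadcast::{Receiver, RecvError};
+//!
+//! async fn consume(mut rx: Receiver<String>) {
+//!     loop {
+//!         match rx.recv().await {
+//!             Ok(value) => println!("got {}", value),
+//!             // `n` messages were skipped; the next call resumes from the
+//!             // oldest value still retained by the channel.
+//!             Err(RecvError::Lagged(n)) => eprintln!("lagged by {} messages", n),
+//!             // All senders have dropped; no more values will arrive.
+//!             Err(RecvError::Closed) => return,
+//!         }
+//!     }
+//! }
+//! ```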
+//!
+//! ## Closing
+//!
+//! When **all** [`Sender`] handles have been dropped, no new values may be
+//! sent. At this point, the channel is "closed". Once a receiver has received
+//! all values retained by the channel, the next call to [`recv`] will return
+//! with [`RecvError::Closed`].
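+//!
+//! A minimal sketch of draining a channel after the last sender is dropped:
+//!
+//! ```
+//! use tokio::sync::broadcast;
+//!
+//! #[tokio::main]
+//! async fn main() {
+//!     let (tx, mut rx) = broadcast::channel(16);
+//!
+//!     tx.send(10).unwrap();
+//!     drop(tx);
+//!
+//!     // The retained value is still delivered...
+//!     assert_eq!(10, rx.recv().await.unwrap());
+//!     // ...then the channel reports closed.
+//!     assert!(rx.recv().await.is_err());
+//! }
+//! ```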
+//!
+//! [`Sender`]: crate::sync::broadcast::Sender
+//! [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe
+//! [`Receiver`]: crate::sync::broadcast::Receiver
+//! [`channel`]: crate::sync::broadcast::channel
+//! [`RecvError::Lagged`]: crate::sync::broadcast::RecvError::Lagged
+//! [`RecvError::Closed`]: crate::sync::broadcast::RecvError::Closed
+//! [`recv`]: crate::sync::broadcast::Receiver::recv
+//!
+//! # Examples
+//!
+//! Basic usage
+//!
+//! ```
+//! use tokio::sync::broadcast;
+//!
+//! #[tokio::main]
+//! async fn main() {
+//! let (tx, mut rx1) = broadcast::channel(16);
+//! let mut rx2 = tx.subscribe();
+//!
+//! tokio::spawn(async move {
+//! assert_eq!(rx1.recv().await.unwrap(), 10);
+//! assert_eq!(rx1.recv().await.unwrap(), 20);
+//! });
+//!
+//! tokio::spawn(async move {
+//! assert_eq!(rx2.recv().await.unwrap(), 10);
+//! assert_eq!(rx2.recv().await.unwrap(), 20);
+//! });
+//!
+//! tx.send(10).unwrap();
+//! tx.send(20).unwrap();
+//! }
+//! ```
+//!
+//! Handling lag
+//!
+//! ```
+//! use tokio::sync::broadcast;
+//!
+//! #[tokio::main]
+//! async fn main() {
+//! let (tx, mut rx) = broadcast::channel(2);
+//!
+//! tx.send(10).unwrap();
+//! tx.send(20).unwrap();
+//! tx.send(30).unwrap();
+//!
+//! // The receiver lagged behind
+//! assert!(rx.recv().await.is_err());
+//!
+//! // At this point, we can abort or continue with lost messages
+//!
+//! assert_eq!(20, rx.recv().await.unwrap());
+//! assert_eq!(30, rx.recv().await.unwrap());
+//! }
+//! ```
+
+use crate::loom::cell::CausalCell;
+use crate::loom::future::AtomicWaker;
+use crate::loom::sync::{Arc, Condvar, Mutex};
+use crate::loom::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, spin_loop_hint};
+
+use std::fmt;
+use std::ptr;
+use std::sync::atomic::Ordering::SeqCst;
+use std::task::{Context, Poll, Waker};
+use std::usize;
+
+/// Sending-half of the [`broadcast`] channel.
+///
+/// May be used from many threads. Messages can be sent with
+/// [`send`][Sender::send].
+///
+/// # Examples
+///
+/// ```
+/// use tokio::sync::broadcast;
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let (tx, mut rx1) = broadcast::channel(16);
+/// let mut rx2 = tx.subscribe();
+///
+/// tokio::spawn(async move {
+/// assert_eq!(rx1.recv().await.unwrap(), 10);
+/// assert_eq!(rx1.recv().await.unwrap(), 20);
+/// });
+///
+/// tokio::spawn(async move {
+/// assert_eq!(rx2.recv().await.unwrap(), 10);
+/// assert_eq!(rx2.recv().await.unwrap(), 20);
+/// });
+///
+/// tx.send(10).unwrap();
+/// tx.send(20).unwrap();
+/// }
+/// ```
+///
+/// [`broadcast`]: crate::sync::broadcast
+pub struct Sender<T> {
+ shared: Arc<Shared<T>>,
+}
+
+/// Receiving-half of the [`broadcast`] channel.
+///
+/// A single handle must not be used from multiple tasks concurrently. Messages
+/// may be retrieved using [`recv`][Receiver::recv].
+///
+/// # Examples
+///
+/// ```
+/// use tokio::sync::broadcast;
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let (tx, mut rx1) = broadcast::channel(16);
+/// let mut rx2 = tx.subscribe();
+///
+/// tokio::spawn(async move {
+/// assert_eq!(rx1.recv().await.unwrap(), 10);
+/// assert_eq!(rx1.recv().await.unwrap(), 20);
+/// });
+///
+/// tokio::spawn(async move {
+/// assert_eq!(rx2.recv().await.unwrap(), 10);
+/// assert_eq!(rx2.recv().await.unwrap(), 20);
+/// });
+///
+/// tx.send(10).unwrap();
+/// tx.send(20).unwrap();
+/// }
+/// ```
+///
+/// [`broadcast`]: crate::sync::broadcast
+pub struct Receiver<T> {
+ /// State shared with all receivers and senders.
+ shared: Arc<Shared<T>>,
+
+ /// Next position to read from
+ next: u64,
+
+ /// Waiter state
+ wait: Arc<WaitNode>,
+}
+
+/// Error returned by [`Sender::send`][Sender::send].
+///
+/// A **send** operation can only fail if there are no active receivers,
+/// implying that the message could never be received. The error contains the
+/// message being sent as a payload so it can be recovered.
+#[derive(Debug)]
+pub struct SendError<T>(pub T);
+
+/// An error returned from the [`recv`] function on a [`Receiver`].
+///
+/// [`recv`]: crate::sync::broadcast::Receiver::recv
+/// [`Receiver`]: crate::sync::broadcast::Receiver
+#[derive(Debug)]
+pub enum RecvError {
+    /// There are no more active senders, implying no further messages will
+    /// ever be sent.
+ Closed,
+
+ /// The receiver lagged too far behind and has been forcibly disconnected.
+ /// Attempting to receive again will return the oldest message still
+ /// retained by the channel.
+ ///
+ /// Includes the number of skipped messages.
+ Lagged(u64),
+}
+
+/// An error returned from the [`try_recv`] function on a [`Receiver`].
+///
+/// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv
+/// [`Receiver`]: crate::sync::broadcast::Receiver
+#[derive(Debug)]
+pub enum TryRecvError {
+ /// The channel is currently empty. There are still active
+ /// [`Sender`][Sender] handles, so data may yet become available.
+ Empty,
+
+    /// There are no more active senders, implying no further messages will
+    /// ever be sent.
+ Closed,
+
+ /// The receiver lagged too far behind and has been forcibly disconnected.
+ /// Attempting to receive again will return the oldest message still
+ /// retained by the channel.
+ ///
+ /// Includes the number of skipped messages.
+ Lagged(u64),
+}
+
+/// Data shared between senders and receivers
+struct Shared<T> {
+    /// Slots in the channel
+ buffer: Box<[Slot<T>]>,
+
+    /// Mask used to map a position to a buffer index
+ mask: usize,
+
+ /// Tail of the queue
+ tail: Mutex<Tail>,
+
+ /// Notifies a sender that the slot is unlocked
+ condvar: Condvar,
+
+ /// Stack of pending waiters
+ wait_stack: AtomicPtr<WaitNode>,
+
+ /// Number of outstanding Sender handles
+ num_tx: AtomicUsize,
+}
+
+/// Tracks the tail of the channel: the next write position and receiver count
+struct Tail {
+ /// Next position to write to
+ pos: u64,
+
+ /// Number of active receivers
+ rx_cnt: usize,
+}
+
+/// Slot in the ring buffer
+struct Slot<T> {
+    /// Remaining number of receivers that are expected to see this value.
+ ///
+ /// When this goes to zero, the value is released.
+ rem: AtomicUsize,
+
+    /// Used to lock the `write` field. The least-significant bit is set while
+    /// a sender holds the lock; each concurrent reader adds two.
+    lock: AtomicUsize,
+
+ /// The value being broadcast
+ ///
+    /// Synchronized by `lock`
+ write: Write<T>,
+}
+
+/// A write in the buffer
+struct Write<T> {
+ /// Uniquely identifies this write
+ pos: CausalCell<u64>,
+
+ /// The written value
+ val: CausalCell<Option<T>>,
+}
+
+/// Tracks a waiting receiver
+#[derive(Debug)]
+struct WaitNode {
+ /// True if queued
+ queued: AtomicBool,
+
+    /// Task to wake when a new value is made available.
+ waker: AtomicWaker,
+
+    /// Next pointer in the stack of waiting receivers.
+ next: CausalCell<*const WaitNode>,
+}
+
+struct RecvGuard<'a, T> {
+ slot: &'a Slot<T>,
+ condvar: &'a Condvar,
+}
+
+/// Max number of receivers. Half the `usize` range is reserved so the per-slot
+/// reader lock, which counts in increments of two, cannot overflow.
+const MAX_RECEIVERS: usize = usize::MAX >> 1;
+
+/// Create a bounded, multi-producer, multi-consumer channel where each sent
+/// value is broadcast to all active receivers.
+///
+/// All data sent on [`Sender`] will become available on every active
+/// [`Receiver`] in the same order as it was sent.
+///
+/// The `Sender` can be cloned to `send` to the same channel from multiple
+/// points in the process or it can be used concurrently from an `Arc`. New
+/// `Receiver` handles are created by calling [`Sender::subscribe`].
+///
+/// If all [`Receiver`] handles are dropped, the `send` method will return a
+/// [`SendError`]. Similarly, if all [`Sender`] handles are dropped, the [`recv`]
+/// method will return a [`RecvError`].
+///
+/// [`Sender`]: crate::sync::broadcast::Sender
+/// [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe
+/// [`Receiver`]: crate::sync::broadcast::Receiver
+/// [`recv`]: crate::sync::broadcast::Receiver::recv
+/// [`SendError`]: crate::sync::broadcast::SendError
+/// [`RecvError`]: crate::sync::broadcast::RecvError
+///
+/// # Examples
+///
+/// ```
+/// use tokio::sync::broadcast;
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let (tx, mut rx1) = broadcast::channel(16);
+/// let mut rx2 = tx.subscribe();
+///
+/// tokio::spawn(async move {
+/// assert_eq!(rx1.recv().await.unwrap(), 10);
+/// assert_eq!(rx1.recv().await.unwrap(), 20);
+/// });
+///
+/// tokio::spawn(async move {
+/// assert_eq!(rx2.recv().await.unwrap(), 10);
+/// assert_eq!(rx2.recv().await.unwrap(), 20);
+/// });
+///
+/// tx.send(10).unwrap();
+/// tx.send(20).unwrap();
+/// }
+/// ```
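+///
+/// A minimal sketch of cloning the `Sender` to send from two tasks:
+///
+/// ```
+/// use tokio::sync::broadcast;
+///
+/// #[tokio::main]
+/// async fn main() {
+///     let (tx, mut rx) = broadcast::channel(16);
+///     let tx2 = tx.clone();
+///
+///     tokio::spawn(async move {
+///         tx.send(10).unwrap();
+///     });
+///
+///     tokio::spawn(async move {
+///         tx2.send(20).unwrap();
+///     });
+///
+///     // The two sends may arrive in either order.
+///     let mut sum = 0;
+///     sum += rx.recv().await.unwrap();
+///     sum += rx.recv().await.unwrap();
+///     assert_eq!(30, sum);
+/// }
+/// ```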
+pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
+    assert!(capacity > 0, "capacity must be greater than zero");
+ assert!(capacity <= usize::MAX >> 1, "requested capacity too large");
+
+    // Round up to a power of two; the usable capacity may exceed the request
+ capacity = capacity.next_power_of_two();
+
+ let mut buffer = Vec::with_capacity(capacity);
+
+ for i in 0..capacity {
+ buffer.push(Slot {
+ rem: AtomicUsize::new(0),
+ lock: AtomicUsize::new(0),
+            write: Write {
+                // Start one "lap" behind so an unwritten slot is never
+                // mistaken for the value a receiver expects to read.
+                pos: CausalCell::new((i as u64).wrapping_sub(capacity as u64)),
+                val: CausalCell::new(None),
+            },
+ });
+ }
+
+ let shared = Arc::new(Shared {
+ buffer: buffer.into_boxed_slice(),
+ mask: capacity - 1,
+ tail: Mutex::new(Tail {
+ pos: 0,
+ rx_cnt: 1,
+ }),
+ condvar: Condvar::new(),
+ wait_stack: AtomicPtr::new(ptr::null_mut()),
+ num_tx: AtomicUsize::new(1),
+ });
+
+ let rx = Receiver {
+ shared: shared.clone(),
+ next: 0,
+ wait: Arc::new(WaitNode {
+ queued: AtomicBool::new(false),
+ waker: AtomicWaker::new(),
+ next: CausalCell::new(ptr::null()),
+ }),
+ };
+
+ let tx = Sender {
+ shared,
+ };
+
+ (tx, rx)
+}
+
+unsafe impl<T: Send> Send for Sender<T> {}
+unsafe impl<T: Send> Sync for Sender<T> {}
+
+unsafe impl<T: Send> Send for Receiver<T> {}
+unsafe impl<T: Send> Sync for Receiver<T> {}
+
+impl<T> Sender<T> {
+ /// Attempts to send a value to all active [`Receiver`] handles, returning
+ /// it back if it could not be sent.
+ ///
+ /// A successful send occurs when there is at least one active [`Receiver`]
+ /// handle. An unsuccessful send would be one where all associated
+ /// [`Receiver`] handles have already been dropped.
+ ///
+ /// # Return
+ ///
+ /// On success, the number of subscribed [`Receiver`] handles is returned.
+    /// This does not guarantee that this number of receivers will see the
+    /// message, as a receiver may drop before receiving it.
+ ///
+ /// # Note
+ ///
+ /// A return value of `Ok` **does not** mean that the sent value will be
+ /// observed by all or any of the active [`Receiver`] handles. [`Receiver`]
+ /// handles may be dropped before receiving the sent message.
+ ///
+ /// A return value of `Err` **does not** mean that future calls to `send`
+ /// will fail. New [`Receiver`] handles may be created by calling
+ /// [`subscribe`].
+ ///
+ /// [`Receiver`]: crate::sync::broadcast::Receiver
+ /// [`subscribe`]: crate::sync::broadcast::Sender::subscribe
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx1) = broadcast::channel(16);
+ /// let mut rx2 = tx.subscribe();
+ ///
+ /// tokio::spawn(async move {
+ /// assert_eq!(rx1.recv().await.unwrap(), 10);
+ /// assert_eq!(rx1.recv().await.unwrap(), 20);
+ /// });
+ ///
+ /// tokio::spawn(async move {
+ /// assert_eq!(rx2.recv().await.unwrap(), 10);
+ /// assert_eq!(rx2.recv().await.unwrap(), 20);
+ /// });
+ ///
+ /// tx.send(10).unwrap();
+ /// tx.send(20).unwrap();
+ /// }
+ /// ```
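+    ///
+    /// A minimal sketch of recovering the value from a failed send:
+    ///
+    /// ```
+    /// use tokio::sync::broadcast;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let (tx, rx) = broadcast::channel(16);
+    ///
+    ///     // Dropping the only receiver makes `send` fail...
+    ///     drop(rx);
+    ///     let err = tx.send(10).unwrap_err();
+    ///
+    ///     // ...and the rejected value comes back in the `SendError`.
+    ///     assert_eq!(10, err.0);
+    /// }
+    /// ```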
+ pub fn send(&self, value: T) -> Result<usize, SendError<T>> {
+ self.send2(Some(value))
+ .map_err(|SendError(maybe_v)| SendError(maybe_v.unwrap()))
+ }
+
+ /// Create a new [`Receiver`] handle that will receive values sent **after**
+ /// this call to `subscribe`.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, _rx) = broadcast::channel(16);
+ ///
+ /// // Will not be seen
+ /// tx.send(10).unwrap();
+ ///
+ /// let mut rx = tx.subscribe();
+ ///
+ /// tx.send(20).unwrap();
+ ///
+ /// let value = rx.recv().await.unwrap();
+ /// assert_eq!(20, value);
+ /// }
+ /// ```
+ pub fn subscribe(&self) -> Receiver<T> {
+ let shared = self.shared.clone();
+
+ let mut tail = shared.tail.lock().unwrap();
+
+ if tail.rx_cnt == MAX_RECEIVERS {
+ panic!("max receivers");
+ }
+
+ tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");
+ let next = tail.pos;
+
+ drop(tail);
+
+ Receiver {
+ shared,
+ next,
+ wait: Arc::new(WaitNode {
+ queued: AtomicBool::new(false),
+ waker: AtomicWaker::new(),
+ next: CausalCell::new(ptr::null()),
+ }),
+ }
+ }
+
+ /// Returns the number of active receivers
+ ///
+ /// An active receiver is a [`Receiver`] handle returned from [`channel`] or
+ /// [`subscribe`]. These are the handles that will receive values sent on
+ /// this [`Sender`].
+ ///
+ /// # Note
+ ///
+ /// It is not guaranteed that a sent message will reach this number of
+ /// receivers. Active receivers may never call [`recv`] again before
+ /// dropping.
+ ///
+ /// [`recv`]: crate::sync::broadcast::Receiver::recv
+ /// [`Receiver`]: crate::sync::broadcast::Receiver
+ /// [`Sender`]: crate::sync::broadcast::Sender
+ /// [`subscribe`]: crate::sync::broadcast::Sender::subscribe
+ /// [`channel`]: crate::sync::broadcast::channel
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, _rx1) = broadcast::channel(16);
+ ///
+ /// assert_eq!(1, tx.receiver_count());
+ ///
+ /// let mut _rx2 = tx.subscribe();
+ ///
+ /// assert_eq!(2, tx.receiver_count());
+ ///
+ /// tx.send(10).unwrap();
+ /// }
+ /// ```
+ pub fn receiver_count(&self) -> usize {
+ let tail = self.shared.tail.lock().unwrap();
+ tail.rx_cnt
+ }
+
+    /// Broadcasts `Some(value)`, or the close sentinel when `value` is `None`;
+    /// receivers observe the sentinel as `RecvError::Closed`.
+    fn send2(&self, value: Option<T>) -> Result<usize, SendError<Option<T>>> {
+ let mut tail = self.shared.tail.lock().unwrap();
+
+ if tail.rx_cnt == 0 {
+ return Err(SendError(value));
+ }
+
+ // Position to write into
+ let pos = tail.pos;
+ let rem = tail.rx_cnt;
+ let idx = (pos & self.shared.mask as u64) as usize;
+
+ // Update the tail position
+ tail.pos = tail.pos.wrapping_add(1);
+
+ // Get the slot
+ let slot = &self.shared.buffer[idx];
+
+        // Acquire the write lock by setting the least-significant bit
+        let mut prev = slot.lock.fetch_add(1, SeqCst);
+
+        // All bits above the LSB count concurrent readers
+        while prev & !1 != 0 {
+ // Concurrent readers, we must go to sleep
+ tail = self.shared.condvar.wait(tail).unwrap();
+
+ prev = slot.lock.load(SeqCst);
+ }
+
+ // Release the mutex
+ drop(tail);
+
+ // Slot lock acquired
+ slot.write.pos.with_mut(|ptr| unsafe { *ptr = pos });
+ slot.write.val.with_mut(|ptr| unsafe { *ptr = value });
+
+ // Set remaining receivers
+ slot.rem.store(rem, SeqCst);
+
+ // Release the slot lock
+ slot.lock.store(0, SeqCst);
+
+ // Notify waiting receivers
+ self.notify_rx();
+
+ Ok(rem)
+ }
+
+ fn notify_rx(&self) {
+ let mut curr = self.shared.wait_stack.swap(ptr::null_mut(), SeqCst) as *const WaitNode;
+
+ while !curr.is_null() {
+ let waiter = unsafe { Arc::from_raw(curr) };
+
+ // Update `curr` before toggling `queued` and waking
+ curr = waiter.next.with(|ptr| unsafe { *ptr });
+
+ // Unset queued
+ waiter.queued.store(false, SeqCst);
+
+ // Wake
+ waiter.waker.wake();
+ }
+ }
+}
+
+impl<T> Clone for Sender<T> {
+ fn clone(&self) -> Sender<T> {
+ let shared = self.shared.clone();
+ shared.num_tx.fetch_add(1, SeqCst);
+
+ Sender { shared }
+ }
+}
+
+impl<T> Drop for Sender<T> {
+ fn drop(&mut self) {
+        if 1 == self.shared.num_tx.fetch_sub(1, SeqCst) {
+            // This was the last sender; broadcast the close sentinel
+            let _ = self.send2(None);
+ }
+ }
+}
+
+impl<T> Receiver<T> {
+ /// Lock the next value if there is one.
+ ///
+ /// The caller is responsible for unlocking
+ fn recv_ref(&mut self, spin: bool) -> Result<RecvGuard<'_, T>, TryRecvError> {
+ let idx = (self.next & self.shared.mask as u64) as usize;
+
+ // The slot holding the next value to read
+ let slot = &self.shared.buffer[idx];
+
+ // Lock the slot
+ if !slot.try_rx_lock() {
+ if spin {
+ while !slot.try_rx_lock() {
+ spin_loop_hint();
+ }
+ } else {
+ return Err(TryRecvError::Empty);
+ }
+ }
+
+ let guard = RecvGuard {
+ slot,
+ condvar: &self.shared.condvar,
+ };
+
+        if guard.pos() != self.next {
+            let pos = guard.pos();
+
+            guard.drop_no_rem_dec();
+
+            // The slot still holds the value from one "lap" back, meaning the
+            // receiver is caught up and the channel is currently empty.
+            if pos.wrapping_add(self.shared.buffer.len() as u64) == self.next {
+                return Err(TryRecvError::Empty);
+            } else {
+ let tail = self.shared.tail.lock().unwrap();
+
+ // `tail.pos` points to the slot the **next** send writes to.
+ // Because a receiver is lagging, this slot also holds the
+ // oldest value. To make the positions match, we subtract the
+ // capacity.
+ let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64);
+ let missed = next.wrapping_sub(self.next);
+
+ self.next = next;
+
+ return Err(TryRecvError::Lagged(missed));
+ }
+ }
+
+ self.next = self.next.wrapping_add(1);
+
+ Ok(guard)
+ }
+}
+
+impl<T> Receiver<T>
+where
+ T: Clone,
+{
+ /// Attempts to return a pending value on this receiver without awaiting.
+ ///
+ /// This is useful for a flavor of "optimistic check" before deciding to
+ /// await on a receiver.
+ ///
+    /// Compared with [`recv`], this function has three failure cases instead of
+    /// two (one for closed, one for a lagging receiver, and one for an empty
+    /// buffer).
+ ///
+ /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have
+ /// dropped, indicating that no further values can be sent on the channel.
+ ///
+    /// If the [`Receiver`] handle falls behind, once the channel is full, newly
+    /// sent values will overwrite old values. At this point, a call to
+    /// [`try_recv`] will return with `Err(TryRecvError::Lagged)` and the
+    /// [`Receiver`]'s internal cursor is updated to point to the oldest value
+    /// still held by the channel. A subsequent call to [`try_recv`] will return
+    /// this value **unless** it has since been overwritten. If there are no
+    /// values to receive, `Err(TryRecvError::Empty)` is returned.
+ ///
+    /// [`recv`]: crate::sync::broadcast::Receiver::recv
+    /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv
+    /// [`Receiver`]: crate::sync::broadcast::Receiver
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx) = broadcast::channel(16);
+ ///
+ /// assert!(rx.try_recv().is_err());
+ ///
+ /// tx.send(10).unwrap();
+ ///
+ /// let value = rx.try_recv().unwrap();
+ /// assert_eq!(10, value);
+ /// }
+ /// ```
+ pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
+ let guard = self.recv_ref(false)?;
+ guard.clone_value().ok_or(TryRecvError::Closed)
+ }
+
+ #[doc(hidden)] // TODO: document
+    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
+        // Fast path: a value may already be available.
+        if let Some(value) = ok_empty(self.try_recv())? {
+            return Poll::Ready(Ok(value));
+        }
+
+        self.register_waker(cx.waker());
+
+        // Check again after registering the waker; a value sent between the
+        // first check and registration would otherwise be missed until the
+        // next wakeup.
+        if let Some(value) = ok_empty(self.try_recv())? {
+            Poll::Ready(Ok(value))
+        } else {
+            Poll::Pending
+        }
+    }
+
+ /// Receive the next value for this receiver.
+ ///
+ /// Each [`Receiver`] handle will receive a clone of all values sent
+ /// **after** it has subscribed.
+ ///
+ /// `Err(RecvError::Closed)` is returned when all `Sender` halves have
+ /// dropped, indicating that no further values can be sent on the channel.
+ ///
+ /// If the [`Receiver`] handle falls behind, once the channel is full, newly
+ /// sent values will overwrite old values. At this point, a call to [`recv`]
+ /// will return with `Err(RecvError::Lagged)` and the [`Receiver`]'s
+ /// internal cursor is updated to point to the oldest value still held by
+    /// the channel. A subsequent call to [`recv`] will return this value
+    /// **unless** it has since been overwritten.
+ ///
+ /// [`Receiver`]: crate::sync::broadcast::Receiver
+ /// [`recv`]: crate::sync::broadcast::Receiver::recv
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx1) = broadcast::channel(16);
+ /// let mut rx2 = tx.subscribe();
+ ///
+ /// tokio::spawn(async move {
+ /// assert_eq!(rx1.recv().await.unwrap(), 10);
+ /// assert_eq!(rx1.recv().await.unwrap(), 20);
+ /// });
+ ///
+ /// tokio::spawn(async move {
+ /// assert_eq!(rx2.recv().await.unwrap(), 10);
+ /// assert_eq!(rx2.recv().await.unwrap(), 20);
+ /// });
+ ///
+ /// tx.send(10).unwrap();
+ /// tx.send(20).unwrap();
+ /// }
+ /// ```
+ ///
+ /// Handling lag
+ ///
+ /// ```
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx) = broadcast::channel(2);
+ ///
+ /// tx.send(10).unwrap();
+ /// tx.send(20).unwrap();
+ /// tx.send(30).unwrap();
+ ///
+ /// // The receiver lagged behind
+ /// assert!(rx.recv().await.is_err());
+ ///
+ /// // At this point, we can abort or continue with lost messages
+ ///
+ /// assert_eq!(20, rx.recv().await.unwrap());
+ /// assert_eq!(30, rx.recv().await.unwrap());
+ /// }
+    /// ```
+ pub async fn recv(&mut self) -> Result<T, RecvError> {
+ use crate::future::poll_fn;
+
+ poll_fn(|cx| self.poll_recv(cx)).await
+ }
+
+    fn register_waker(&self, waker: &Waker) {
+        self.wait.waker.register_by_ref(waker);
+
+ if !self.wait.queued.load(SeqCst) {
+ // Set `queued` before queuing.
+ self.wait.queued.store(true, SeqCst);
+
+ let mut curr = self.shared.wait_stack.load(SeqCst);
+
+ // The ref count is decremented in `notify_rx` when all nodes are
+ // removed from the waiter stack.
+ let node = Arc::into_raw(self.wait.clone()) as *mut _;
+
+ loop {
+ // Safety: `queued == false` means the caller has exclusive
+ // access to `self.wait.next`.
+ self.wait.next.with_mut(|ptr| unsafe { *ptr = curr });
+
+ let res = self.shared.wait_stack.compare_exchange(curr, node, SeqCst, SeqCst);
+
+ match res {
+ Ok(_) => return,
+ Err(actual) => curr = actual,
+ }
+ }
+ }
+ }
+}
+
+impl<T> Drop for Receiver<T> {
+ fn drop(&mut self) {
+ let mut tail = self.shared.tail.lock().unwrap();
+
+ tail.rx_cnt -= 1;
+ let until = tail.pos;
+
+ drop(tail);
+
+ while self.next != until {
+ match self.recv_ref(true) {
+ // Ignore the value
+ Ok(_) => {}
+ // The channel is closed
+ Err(TryRecvError::Closed) => break,
+ // Ignore lagging, we will catch up
+ Err(TryRecvError::Lagged(..)) => {}
+ // Can't be empty
+ Err(TryRecvError::Empty) => panic!("unexpected empty broadcast channel"),
+ }
+ }
+ }
+}
+
+impl<T> Drop for Shared<T> {
+ fn drop(&mut self) {
+ // Clear the wait stack
+ let mut curr = *self.wait_stack.get_mut() as *const WaitNode;
+
+ while !curr.is_null() {
+ let waiter = unsafe { Arc::from_raw(curr) };
+ curr = waiter.next.with(|ptr| unsafe { *ptr });
+ }
+ }
+}
+
+impl<T> fmt::Debug for Sender<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(fmt, "broadcast::Sender")
+ }
+}
+
+impl<T> fmt::Debug for Receiver<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(fmt, "broadcast::Receiver")
+ }
+}
+
+impl<T> Slot<T> {
+ /// Try to lock the slot for a receiver. If `false`, then a sender holds the
+ /// lock and the calling task will be notified once the sender has released
+ /// the lock.
+ fn try_rx_lock(&self) -> bool {
+ let mut curr = self.lock.load(SeqCst);
+
+ loop {
+ if curr & 1 == 1 {
+ // Locked by sender
+ return false;
+ }
+
+ // Only increment (by 2) if the LSB "lock" bit is not set.
+ let res = self.lock.compare_exchange(curr, curr + 2, SeqCst, SeqCst);
+
+ match res {
+ Ok(_) => return true,
+ Err(actual) => curr = actual,
+ }
+ }
+ }
+
+ fn rx_unlock(&self, condvar: &Condvar, rem_dec: bool) {
+ if rem_dec {
+ // Decrement the remaining counter
+ if 1 == self.rem.fetch_sub(1, SeqCst) {
+ // Last receiver, drop the value
+ self.write.val.with_mut(|ptr| unsafe { *ptr = None });
+ }
+ }
+
+ let prev = self.lock.fetch_sub(2, SeqCst);
+
+ if prev & 1 == 1 {
+ // Sender waiting for lock
+ condvar.notify_one();
+ }
+ }
+}
+
+impl<'a, T> RecvGuard<'a, T> {
+ fn pos(&self) -> u64 {
+ self.slot.write.pos.with(|ptr| unsafe { *ptr })
+ }
+
+ fn clone_value(&self) -> Option<T>
+ where
+ T: Clone,
+ {
+ self.slot.write.val.with(|ptr| unsafe { (*ptr).clone() })
+ }
+
+ fn drop_no_rem_dec(self) {
+ use std::mem;
+
+ self.slot.rx_unlock(self.condvar, false);
+
+ mem::forget(self);
+ }
+}
+
+impl<'a, T> Drop for RecvGuard<'a, T> {
+ fn drop(&mut self) {
+ self.slot.rx_unlock(self.condvar, true)
+ }
+}
+
+fn ok_empty<T>(res: Result<T, TryRecvError>) -> Result<Option<T>, RecvError> {
+ match res {
+ Ok(value) => Ok(Some(value)),
+ Err(TryRecvError::Empty) => Ok(None),
+ Err(TryRecvError::Lagged(n)) => Err(RecvError::Lagged(n)),
+ Err(TryRecvError::Closed) => Err(RecvError::Closed),
+ }
+}
diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs
index a05accbf..6ed91607 100644
--- a/tokio/src/sync/mod.rs
+++ b/tokio/src/sync/mod.rs
@@ -19,6 +19,8 @@ cfg_sync! {
mod barrier;
pub use barrier::{Barrier, BarrierWaitResult};
+ pub mod broadcast;
+
pub mod mpsc;
mod mutex;
diff --git a/tokio/src/sync/tests/loom_broadcast.rs b/tokio/src/sync/tests/loom_broadcast.rs
new file mode 100644
index 00000000..53f76a25
--- /dev/null
+++ b/tokio/src/sync/tests/loom_broadcast.rs
@@ -0,0 +1,144 @@
+use crate::sync::broadcast;
+use crate::sync::broadcast::RecvError::{Closed, Lagged};
+
+use loom::future::block_on;
+use loom::sync::Arc;
+use loom::thread;
+use tokio_test::{assert_err, assert_ok};
+
+// An `Arc` is used as the value in order to detect memory leaks.
+#[test]
+fn broadcast_two() {
+ loom::model(|| {
+ let (tx, mut rx1) = broadcast::channel::<Arc<&'static str>>(16);
+ let mut rx2 = tx.subscribe();
+
+ let th1 = thread::spawn(move || {
+ block_on(async {
+ let v = assert_ok!(rx1.recv().await);
+ assert_eq!(*v, "hello");
+
+ let v = assert_ok!(rx1.recv().await);
+ assert_eq!(*v, "world");
+
+ match assert_err!(rx1.recv().await) {
+ Closed => {}
+ _ => panic!(),
+ }
+ });
+ });
+
+ let th2 = thread::spawn(move || {
+ block_on(async {
+ let v = assert_ok!(rx2.recv().await);
+ assert_eq!(*v, "hello");
+
+ let v = assert_ok!(rx2.recv().await);
+ assert_eq!(*v, "world");
+
+ match assert_err!(rx2.recv().await) {
+ Closed => {}
+ _ => panic!(),
+ }
+ });
+ });
+
+ assert_ok!(tx.send(Arc::new("hello")));
+ assert_ok!(tx.send(Arc::new("world")));
+ drop(tx);
+