path: root/tokio/src/sync/broadcast.rs
author Carl Lerche <me@carllerche.com> 2020-05-12 15:09:43 -0700
committer GitHub <noreply@github.com> 2020-05-12 15:09:43 -0700
commit fb7dfcf4322b5e60604815aea91266b88f0b7823 (patch)
tree aeba04a918be8a00eb09f6001a4f7946bd188c66 /tokio/src/sync/broadcast.rs
parent a32f918671ef641affbfcc4d4005ab738da795df (diff)
sync: use intrusive list strategy for broadcast (#2509)
Previously, in the broadcast channel, receiver wakers were passed to the sender via an atomic stack with allocated nodes. When a message was sent, the stack was drained. This caused a problem when many receivers pushed a waiter node and then dropped: if no values were ever sent, the waiter nodes remained allocated indefinitely.

This patch switches broadcast to use the intrusive linked-list waiter strategy used by `Notify` and `Semaphore`.
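As a rough illustration of the failure mode described above (not part of this commit), the following sketch repeatedly creates receivers whose pending `recv()` calls are dropped without a value ever being sent. Under the old atomic-stack design, each abandoned waiter node lingered on the stack until the next send; with the intrusive list, the receive future's drop handler unlinks its node immediately. This assumes a current tokio with the `sync`, `time`, `rt`, and `macros` features enabled:

    use std::time::Duration;
    use tokio::sync::broadcast;
    use tokio::time::timeout;

    #[tokio::main]
    async fn main() {
        let (tx, rx) = broadcast::channel::<i32>(16);
        drop(rx);

        for _ in 0..1_000 {
            let mut rx = tx.subscribe();
            // The timeout fires, the pending `recv()` future is dropped, and
            // its drop handler takes the tail lock and unlinks the waiter node.
            let _ = timeout(Duration::from_millis(1), rx.recv()).await;
        }
        // Nothing was ever sent, yet no waiter nodes accumulate.
    }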
Diffstat (limited to 'tokio/src/sync/broadcast.rs')
-rw-r--r-- tokio/src/sync/broadcast.rs | 506
1 file changed, 379 insertions(+), 127 deletions(-)
diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs
index 9873dcb7..0c8716f7 100644
--- a/tokio/src/sync/broadcast.rs
+++ b/tokio/src/sync/broadcast.rs
@@ -109,12 +109,15 @@
//! }
use crate::loom::cell::UnsafeCell;
-use crate::loom::future::AtomicWaker;
-use crate::loom::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize};
+use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
+use crate::util::linked_list::{self, LinkedList};
use std::fmt;
-use std::ptr;
+use std::future::Future;
+use std::marker::PhantomPinned;
+use std::pin::Pin;
+use std::ptr::NonNull;
use std::sync::atomic::Ordering::SeqCst;
use std::task::{Context, Poll, Waker};
use std::usize;
@@ -192,8 +195,8 @@ pub struct Receiver<T> {
/// Next position to read from
next: u64,
- /// Waiter state
- wait: Arc<WaitNode>,
+ /// Used to support the deprecated `poll_recv` fn
+ waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>,
}
/// Error returned by [`Sender::send`][Sender::send].
@@ -251,12 +254,9 @@ struct Shared<T> {
/// Mask a position -> index
mask: usize,
- /// Tail of the queue
+ /// Tail of the queue. Includes the rx wait list.
tail: Mutex<Tail>,
- /// Stack of pending waiters
- wait_stack: AtomicPtr<WaitNode>,
-
/// Number of outstanding Sender handles
num_tx: AtomicUsize,
}
@@ -271,6 +271,9 @@ struct Tail {
/// True if the channel is closed
closed: bool,
+
+ /// Receivers waiting for a value
+ waiters: LinkedList<Waiter>,
}
/// Slot in the buffer
@@ -296,23 +299,59 @@ struct Slot<T> {
val: UnsafeCell<Option<T>>,
}
-/// Tracks a waiting receiver
-#[derive(Debug)]
-struct WaitNode {
- /// `true` if queued
- queued: AtomicBool,
+/// An entry in the wait queue
+struct Waiter {
+ /// True if queued
+ queued: bool,
+
+ /// Task waiting on the broadcast channel.
+ waker: Option<Waker>,
- /// Task to wake when a permit is made available.
- waker: AtomicWaker,
+ /// Intrusive linked-list pointers.
+ pointers: linked_list::Pointers<Waiter>,
- /// Next pointer in the stack of waiting senders.
- next: UnsafeCell<*const WaitNode>,
+ /// Should not be `Unpin`.
+ _p: PhantomPinned,
}
struct RecvGuard<'a, T> {
slot: RwLockReadGuard<'a, Slot<T>>,
}
+/// Future for receiving a value from the channel
+struct Recv<R, T>
+where
+ R: AsMut<Receiver<T>>,
+{
+ /// Receiver being waited on
+ receiver: R,
+
+ /// Entry in the waiter `LinkedList`
+ waiter: UnsafeCell<Waiter>,
+
+ _p: std::marker::PhantomData<T>,
+}
+
+/// `AsMut<T>` is not implemented for `T` (coherence). Explicitly implementing
+/// `AsMut` for `Receiver` would make it part of the receiver's public API.
+/// Instead, `Borrow` is used internally to bridge the gap.
+struct Borrow<T>(T);
+
+impl<T> AsMut<Receiver<T>> for Borrow<Receiver<T>> {
+ fn as_mut(&mut self) -> &mut Receiver<T> {
+ &mut self.0
+ }
+}
+
+impl<'a, T> AsMut<Receiver<T>> for Borrow<&'a mut Receiver<T>> {
+ fn as_mut(&mut self) -> &mut Receiver<T> {
+ &mut *self.0
+ }
+}
+
+unsafe impl<R: AsMut<Receiver<T>> + Send, T: Send> Send for Recv<R, T> {}
+unsafe impl<R: AsMut<Receiver<T>> + Sync, T: Send> Sync for Recv<R, T> {}
+
/// Max number of receivers. Reserve space to lock.
const MAX_RECEIVERS: usize = usize::MAX >> 2;
@@ -386,19 +425,15 @@ pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
pos: 0,
rx_cnt: 1,
closed: false,
+ waiters: LinkedList::new(),
}),
- wait_stack: AtomicPtr::new(ptr::null_mut()),
num_tx: AtomicUsize::new(1),
});
let rx = Receiver {
shared: shared.clone(),
next: 0,
- wait: Arc::new(WaitNode {
- queued: AtomicBool::new(false),
- waker: AtomicWaker::new(),
- next: UnsafeCell::new(ptr::null()),
- }),
+ waiter: None,
};
let tx = Sender { shared };
@@ -508,11 +543,7 @@ impl<T> Sender<T> {
Receiver {
shared,
next,
- wait: Arc::new(WaitNode {
- queued: AtomicBool::new(false),
- waker: AtomicWaker::new(),
- next: UnsafeCell::new(ptr::null()),
- }),
+ waiter: None,
}
}
@@ -589,34 +620,31 @@ impl<T> Sender<T> {
slot.val.with_mut(|ptr| unsafe { *ptr = value });
}
- // Release the slot lock before the tail lock
+ // Release the slot lock before notifying the receivers.
drop(slot);
+ tail.notify_rx();
+
// Release the mutex. This must happen after the slot lock is released,
// otherwise the writer lock bit could be cleared while another thread
// is in the critical section.
drop(tail);
- // Notify waiting receivers
- self.notify_rx();
-
Ok(rem)
}
+}
- fn notify_rx(&self) {
- let mut curr = self.shared.wait_stack.swap(ptr::null_mut(), SeqCst) as *const WaitNode;
-
- while !curr.is_null() {
- let waiter = unsafe { Arc::from_raw(curr) };
-
- // Update `curr` before toggling `queued` and waking
- curr = waiter.next.with(|ptr| unsafe { *ptr });
+impl Tail {
+ fn notify_rx(&mut self) {
+ while let Some(mut waiter) = self.waiters.pop_back() {
+ // Safety: the `tail` lock guarding `waiters` is still held.
+ let waiter = unsafe { waiter.as_mut() };
- // Unset queued
- waiter.queued.store(false, SeqCst);
+ assert!(waiter.queued);
+ waiter.queued = false;
- // Wake
- waiter.waker.wake();
+ let waker = waiter.waker.take().unwrap();
+ waker.wake();
}
}
}
@@ -640,15 +668,21 @@ impl<T> Drop for Sender<T> {
impl<T> Receiver<T> {
/// Locks the next value if there is one.
- fn recv_ref(&mut self) -> Result<RecvGuard<'_, T>, TryRecvError> {
+ fn recv_ref(
+ &mut self,
+ waiter: Option<(&UnsafeCell<Waiter>, &Waker)>,
+ ) -> Result<RecvGuard<'_, T>, TryRecvError> {
let idx = (self.next & self.shared.mask as u64) as usize;
// The slot holding the next value to read
let mut slot = self.shared.buffer[idx].read().unwrap();
if slot.pos != self.next {
- // The receiver has read all current values in the channel
- if slot.pos.wrapping_add(self.shared.buffer.len() as u64) == self.next {
+ let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64);
+
+ // The receiver has read all current values in the channel and there
+ // is no waiter to register
+ if waiter.is_none() && next_pos == self.next {
return Err(TryRecvError::Empty);
}
@@ -661,35 +695,83 @@ impl<T> Receiver<T> {
// the slot lock.
drop(slot);
- let tail = self.shared.tail.lock().unwrap();
+ let mut tail = self.shared.tail.lock().unwrap();
// Acquire slot lock again
slot = self.shared.buffer[idx].read().unwrap();
- // `tail.pos` points to the slot that the **next** send writes to. If
- // the channel is closed, the previous slot is the oldest value.
- let mut adjust = 0;
- if tail.closed {
- adjust = 1
- }
- let next = tail
- .pos
- .wrapping_sub(self.shared.buffer.len() as u64 + adjust);
+ // Make sure the position did not change. This could happen in the
+ // unlikely event that the buffer is wrapped between dropping the
+ // read lock and acquiring the tail lock.
+ if slot.pos != self.next {
+ let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64);
+
+ if next_pos == self.next {
+ // Store the waker
+ if let Some((waiter, waker)) = waiter {
+ // Safety: called while locked.
+ unsafe {
+ // Only queue if not already queued
+ waiter.with_mut(|ptr| {
+ // If there is no waker **or** if the currently
+ // stored waker references a **different** task,
+ // track the task's waker to be notified on
+ // receipt of a new value.
+ match (*ptr).waker {
+ Some(ref w) if w.will_wake(waker) => {}
+ _ => {
+ (*ptr).waker = Some(waker.clone());
+ }
+ }
+
+ if !(*ptr).queued {
+ (*ptr).queued = true;
+ tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr));
+ }
+ });
+ }
+ }
+
+ return Err(TryRecvError::Empty);
+ }
- let missed = next.wrapping_sub(self.next);
+ // At this point, the receiver has lagged behind the sender by
+ // more than the channel capacity. The receiver will attempt to
+ // catch up by skipping dropped messages and setting the
+ // internal cursor to the **oldest** message stored by the
+ // channel.
+ //
+ // However, finding the oldest position is a bit more
+ // complicated than `tail-position - buffer-size`. When
+ // the channel is closed, the tail position is incremented to
+ // signal a new `None` message, but `None` is not stored in the
+ // channel itself (see issue #2425 for why).
+ //
+ // To account for this, if the channel is closed, the tail
+ // position is decremented by `buffer-size + 1`.
+ let mut adjust = 0;
+ if tail.closed {
+ adjust = 1
+ }
+ let next = tail
+ .pos
+ .wrapping_sub(self.shared.buffer.len() as u64 + adjust);
- drop(tail);
+ let missed = next.wrapping_sub(self.next);
- // The receiver is slow but no values have been missed
- if missed == 0 {
- self.next = self.next.wrapping_add(1);
+ drop(tail);
- return Ok(RecvGuard { slot });
- }
+ // The receiver is slow but no values have been missed
+ if missed == 0 {
+ self.next = self.next.wrapping_add(1);
- self.next = next;
+ return Ok(RecvGuard { slot });
+ }
+
+ self.next = next;
- return Err(TryRecvError::Lagged(missed));
+ return Err(TryRecvError::Lagged(missed));
+ }
}
self.next = self.next.wrapping_add(1);
@@ -746,22 +828,59 @@ where
/// }
/// ```
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
- let guard = self.recv_ref()?;
+ let guard = self.recv_ref(None)?;
guard.clone_value().ok_or(TryRecvError::Closed)
}
- #[doc(hidden)] // TODO: document
+ #[doc(hidden)]
+ #[deprecated(since = "0.2.21", note = "use async fn recv()")]
pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
- if let Some(value) = ok_empty(self.try_recv())? {
- return Poll::Ready(Ok(value));
+ use Poll::{Pending, Ready};
+
+ // The borrow checker prohibits calling `self.recv_ref` while passing in
+ // a mutable ref to a field (as it should). To work around this,
+ // `waiter` is first *removed* from `self`, then `recv_ref` is called.
+ //
+ // However, for safety, we must ensure that `waiter` is **not** dropped.
+ // It could be contained in the intrusive linked list. The `Receiver`
+ // drop implementation handles cleanup.
+ //
+ // The guard pattern is used to ensure that, on return, even due to
+ // panic, the waiter node is replaced on `self`.
+
+ struct Guard<'a, T> {
+ waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>,
+ receiver: &'a mut Receiver<T>,
}
- self.register_waker(cx.waker());
+ impl<'a, T> Drop for Guard<'a, T> {
+ fn drop(&mut self) {
+ self.receiver.waiter = self.waiter.take();
+ }
+ }
- if let Some(value) = ok_empty(self.try_recv())? {
- Poll::Ready(Ok(value))
- } else {
- Poll::Pending
+ let waiter = self.waiter.take().or_else(|| {
+ Some(Box::pin(UnsafeCell::new(Waiter {
+ queued: false,
+ waker: None,
+ pointers: linked_list::Pointers::new(),
+ _p: PhantomPinned,
+ })))
+ });
+
+ let guard = Guard {
+ waiter,
+ receiver: self,
+ };
+ let res = guard
+ .receiver
+ .recv_ref(Some((&guard.waiter.as_ref().unwrap(), cx.waker())));
+
+ match res {
+ Ok(guard) => Ready(guard.clone_value().ok_or(RecvError::Closed)),
+ Err(TryRecvError::Closed) => Ready(Err(RecvError::Closed)),
+ Err(TryRecvError::Lagged(n)) => Ready(Err(RecvError::Lagged(n))),
+ Err(TryRecvError::Empty) => Pending,
}
}
@@ -830,44 +949,14 @@ where
/// assert_eq!(30, rx.recv().await.unwrap());
/// }
pub async fn recv(&mut self) -> Result<T, RecvError> {
- use crate::future::poll_fn;
-
- poll_fn(|cx| self.poll_recv(cx)).await
- }
-
- fn register_waker(&self, cx: &Waker) {
- self.wait.waker.register_by_ref(cx);
-
- if !self.wait.queued.load(SeqCst) {
- // Set `queued` before queuing.
- self.wait.queued.store(true, SeqCst);
-
- let mut curr = self.shared.wait_stack.load(SeqCst);
-
- // The ref count is decremented in `notify_rx` when all nodes are
- // removed from the waiter stack.
- let node = Arc::into_raw(self.wait.clone()) as *mut _;
-
- loop {
- // Safety: `queued == false` means the caller has exclusive
- // access to `self.wait.next`.
- self.wait.next.with_mut(|ptr| unsafe { *ptr = curr });
-
- let res = self
- .shared
- .wait_stack
- .compare_exchange(curr, node, SeqCst, SeqCst);
-
- match res {
- Ok(_) => return,
- Err(actual) => curr = actual,
- }
- }
- }
+ let fut = Recv::<_, T>::new(Borrow(self));
+ fut.await
}
}
#[cfg(feature = "stream")]
+#[doc(hidden)]
+#[deprecated(since = "0.2.21", note = "use `into_stream()`")]
impl<T> crate::stream::Stream for Receiver<T>
where
T: Clone,
@@ -878,6 +967,7 @@ where
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<T, RecvError>>> {
+ #[allow(deprecated)]
self.poll_recv(cx).map(|v| match v {
Ok(v) => Some(Ok(v)),
lag @ Err(RecvError::Lagged(_)) => Some(lag),
@@ -890,13 +980,30 @@ impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
let mut tail = self.shared.tail.lock().unwrap();
+ if let Some(waiter) = &self.waiter {
+ // safety: tail lock is held
+ let queued = waiter.with(|ptr| unsafe { (*ptr).queued });
+
+ if queued {
+ // Remove the node
+ //
+ // safety: tail lock is held and the wait node is verified to be in
+ // the list.
+ unsafe {
+ waiter.with_mut(|ptr| {
+ tail.waiters.remove((&mut *ptr).into());
+ });
+ }
+ }
+ }
+
tail.rx_cnt -= 1;
let until = tail.pos;
drop(tail);
while self.next != until {
- match self.recv_ref() {
+ match self.recv_ref(None) {
Ok(_) => {}
// The channel is closed
Err(TryRecvError::Closed) => break,
@@ -909,18 +1016,170 @@ impl<T> Drop for Receiver<T> {
}
}
-impl<T> Drop for Shared<T> {
- fn drop(&mut self) {
- // Clear the wait stack
- let mut curr = self.wait_stack.with_mut(|ptr| *ptr as *const WaitNode);
+impl<R, T> Recv<R, T>
+where
+ R: AsMut<Receiver<T>>,
+{
+ fn new(receiver: R) -> Recv<R, T> {
+ Recv {
+ receiver,
+ waiter: UnsafeCell::new(Waiter {
+ queued: false,
+ waker: None,
+ pointers: linked_list::Pointers::new(),
+ _p: PhantomPinned,
+ }),
+ _p: std::marker::PhantomData,
+ }
+ }
- while !curr.is_null() {
- let waiter = unsafe { Arc::from_raw(curr) };
- curr = waiter.next.with(|ptr| unsafe { *ptr });
+ /// A custom `project` implementation is used in place of `pin-project-lite`
+ /// as a custom drop implementation is needed.
+ fn project(self: Pin<&mut Self>) -> (&mut Receiver<T>, &UnsafeCell<Waiter>) {
+ unsafe {
+ // Safety: Receiver is Unpin
+ is_unpin::<&mut Receiver<T>>();
+
+ let me = self.get_unchecked_mut();
+ (me.receiver.as_mut(), &me.waiter)
}
}
}
+impl<R, T> Future for Recv<R, T>
+where
+ R: AsMut<Receiver<T>>,
+ T: Clone,
+{
+ type Output = Result<T, RecvError>;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
+ let (receiver, waiter) = self.project();
+
+ let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) {
+ Ok(value) => value,
+ Err(TryRecvError::Empty) => return Poll::Pending,
+ Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))),
+ Err(TryRecvError::Closed) => return Poll::Ready(Err(RecvError::Closed)),
+ };
+
+ Poll::Ready(guard.clone_value().ok_or(RecvError::Closed))
+ }
+}
+
+cfg_stream! {
+ use futures_core::Stream;
+
+ impl<T: Clone> Receiver<T> {
+ /// Convert the receiver into a `Stream`.
+ ///
+ /// The conversion allows using `Receiver` with APIs that require stream
+ /// values.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::stream::StreamExt;
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, rx) = broadcast::channel(128);
+ ///
+ /// tokio::spawn(async move {
+ /// for i in 0..10_i32 {
+ /// tx.send(i).unwrap();
+ /// }
+ /// });
+ ///
+ /// // Streams must be pinned to iterate.
+ /// tokio::pin! {
+ /// let stream = rx
+ /// .into_stream()
+ /// .filter(Result::is_ok)
+ /// .map(Result::unwrap)
+ /// .filter(|v| v % 2 == 0)
+ /// .map(|v| v + 1);
+ /// }
+ ///
+ /// while let Some(i) = stream.next().await {
+ /// println!("{}", i);
+ /// }
+ /// }
+ /// ```
+ pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> {
+ Recv::new(Borrow(self))
+ }
+ }
+
+ impl<R, T: Clone> Stream for Recv<R, T>
+ where
+ R: AsMut<Receiver<T>>,
+ T: Clone,
+ {
+ type Item = Result<T, RecvError>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ let (receiver, waiter) = self.project();
+
+ let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) {
+ Ok(value) => value,
+ Err(TryRecvError::Empty) => return Poll::Pending,
+ Err(TryRecvError::Lagged(n)) => return Poll::Ready(Some(Err(RecvError::Lagged(n)))),
+ Err(TryRecvError::Closed) => return Poll::Ready(None),
+ };
+
+ Poll::Ready(guard.clone_value().map(Ok))
+ }
+ }
+}
+
+impl<R, T> Drop for Recv<R, T>
+where
+ R: AsMut<Receiver<T>>,
+{
+ fn drop(&mut self) {
+ // Acquire the tail lock. This is required for safety before accessing
+ // the waiter node.
+ let mut tail = self.receiver.as_mut().shared.tail.lock().unwrap();
+
+ // safety: tail lock is held
+ let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });
+
+ if queued {
+ // Remove the node
+ //
+ // safety: tail lock is held and the wait node is verified to be in
+ // the list.
+ unsafe {
+ self.waiter.with_mut(|ptr| {
+ tail.waiters.remove((&mut *ptr).into());
+ });
+ }
+ }
+ }
+}
+
+/// # Safety
+///
+/// `Waiter` is forced to be !Unpin.
+unsafe impl linked_list::Link for Waiter {
+ type Handle = NonNull<Waiter>;
+ type Target = Waiter;
+
+ fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> {
+ *handle
+ }
+
+ unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
+ ptr
+ }
+
+ unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
+ NonNull::from(&mut target.as_mut().pointers)
+ }
+}
+
impl<T> fmt::Debug for Sender<T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "broadcast::Sender")
@@ -952,15 +1211,6 @@ impl<'a, T> Drop for RecvGuard<'a, T> {
}
}
-fn ok_empty<T>(res: Result<T, TryRecvError>) -> Result<Option<T>, RecvError> {
- match res {
- Ok(value) => Ok(Some(value)),
- Err(TryRecvError::Empty) => Ok(None),
- Err(TryRecvError::Lagged(n)) => Err(RecvError::Lagged(n)),
- Err(TryRecvError::Closed) => Err(RecvError::Closed),
- }
-}
-
impl fmt::Display for RecvError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
@@ -983,3 +1233,5 @@ impl fmt::Display for TryRecvError {
}
impl std::error::Error for TryRecvError {}
+
+fn is_unpin<T: Unpin>() {}
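For reference, here is a minimal standalone sketch (illustrative only, with hypothetical names, not tokio's `linked_list` module) of why the intrusive approach makes waiter removal O(1): each node stores its own prev/next pointers, so a dropped waiter can unlink itself without scanning the list or allocating.

    use std::ptr::NonNull;

    struct Node {
        prev: Option<NonNull<Node>>,
        next: Option<NonNull<Node>>,
        id: u32,
    }

    struct List {
        head: Option<NonNull<Node>>,
    }

    impl List {
        /// Link a caller-owned node at the front; the list never allocates.
        unsafe fn push_front(&mut self, mut node: NonNull<Node>) {
            node.as_mut().prev = None;
            node.as_mut().next = self.head;
            if let Some(mut head) = self.head {
                head.as_mut().prev = Some(node);
            }
            self.head = Some(node);
        }

        /// O(1) removal: the node's own pointers locate its neighbors.
        unsafe fn remove(&mut self, node: NonNull<Node>) {
            let (prev, next) = ((*node.as_ptr()).prev, (*node.as_ptr()).next);
            match prev {
                Some(mut p) => p.as_mut().next = next,
                None => self.head = next,
            }
            if let Some(mut n) = next {
                n.as_mut().prev = prev;
            }
        }
    }

    fn main() {
        let mut a = Node { prev: None, next: None, id: 1 };
        let mut b = Node { prev: None, next: None, id: 2 };
        let mut list = List { head: None };
        unsafe {
            list.push_front(NonNull::from(&mut a));
            list.push_front(NonNull::from(&mut b));
            // Dropping a pending receive future corresponds to this unlink step.
            list.remove(NonNull::from(&mut a));
            assert_eq!(list.head.map(|h| h.as_ref().id), Some(2));
        }
    }

In the real implementation the same invariant is upheld by holding the `tail` mutex across every link and unlink, which is why both `Recv::drop` and `Receiver::drop` in the diff above acquire the lock before touching the list.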