diff options
author | Carl Lerche <me@carllerche.com> | 2020-05-12 15:09:43 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-12 15:09:43 -0700 |
commit | fb7dfcf4322b5e60604815aea91266b88f0b7823 (patch) | |
tree | aeba04a918be8a00eb09f6001a4f7946bd188c66 /tokio/src/sync/broadcast.rs | |
parent | a32f918671ef641affbfcc4d4005ab738da795df (diff) |
sync: use intrusive list strategy for broadcast (#2509)
Previously, in the broadcast channel, receiver wakers were passed to the
sender via an atomic stack with allocated nodes. When a message was
sent, the stack was drained. This caused a problem when many receivers
pushed a waiter node then dropped. The waiter node remained indefinitely
in cases where no values were sent.
This patch switches broadcast to use the intrusive linked-list waiter
strategy used by `Notify` and `Semaphore`.
Diffstat (limited to 'tokio/src/sync/broadcast.rs')
-rw-r--r-- | tokio/src/sync/broadcast.rs | 506 |
1 file changed, 379 insertions, 127 deletions
diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 9873dcb7..0c8716f7 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -109,12 +109,15 @@ //! } use crate::loom::cell::UnsafeCell; -use crate::loom::future::AtomicWaker; -use crate::loom::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize}; +use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; +use crate::util::linked_list::{self, LinkedList}; use std::fmt; -use std::ptr; +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; use std::sync::atomic::Ordering::SeqCst; use std::task::{Context, Poll, Waker}; use std::usize; @@ -192,8 +195,8 @@ pub struct Receiver<T> { /// Next position to read from next: u64, - /// Waiter state - wait: Arc<WaitNode>, + /// Used to support the deprecated `poll_recv` fn + waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>, } /// Error returned by [`Sender::send`][Sender::send]. @@ -251,12 +254,9 @@ struct Shared<T> { /// Mask a position -> index mask: usize, - /// Tail of the queue + /// Tail of the queue. Includes the rx wait list. tail: Mutex<Tail>, - /// Stack of pending waiters - wait_stack: AtomicPtr<WaitNode>, - /// Number of outstanding Sender handles num_tx: AtomicUsize, } @@ -271,6 +271,9 @@ struct Tail { /// True if the channel is closed closed: bool, + + /// Receivers waiting for a value + waiters: LinkedList<Waiter>, } /// Slot in the buffer @@ -296,23 +299,59 @@ struct Slot<T> { val: UnsafeCell<Option<T>>, } -/// Tracks a waiting receiver -#[derive(Debug)] -struct WaitNode { - /// `true` if queued - queued: AtomicBool, +/// An entry in the wait queue +struct Waiter { + /// True if queued + queued: bool, + + /// Task waiting on the broadcast channel. + waker: Option<Waker>, - /// Task to wake when a permit is made available. - waker: AtomicWaker, + /// Intrusive linked-list pointers. 
+ pointers: linked_list::Pointers<Waiter>, - /// Next pointer in the stack of waiting senders. - next: UnsafeCell<*const WaitNode>, + /// Should not be `Unpin`. + _p: PhantomPinned, } struct RecvGuard<'a, T> { slot: RwLockReadGuard<'a, Slot<T>>, } +/// Receive a value future +struct Recv<R, T> +where + R: AsMut<Receiver<T>>, +{ + /// Receiver being waited on + receiver: R, + + /// Entry in the waiter `LinkedList` + waiter: UnsafeCell<Waiter>, + + _p: std::marker::PhantomData<T>, +} + +/// `AsMut<T>` is not implemented for `T` (coherence). Explicitly implementing +/// `AsMut` for `Receiver` would be included in the public API of the receiver +/// type. Instead, `Borrow` is used internally to bridge the gap. +struct Borrow<T>(T); + +impl<T> AsMut<Receiver<T>> for Borrow<Receiver<T>> { + fn as_mut(&mut self) -> &mut Receiver<T> { + &mut self.0 + } +} + +impl<'a, T> AsMut<Receiver<T>> for Borrow<&'a mut Receiver<T>> { + fn as_mut(&mut self) -> &mut Receiver<T> { + &mut *self.0 + } +} + +unsafe impl<R: AsMut<Receiver<T>> + Send, T: Send> Send for Recv<R, T> {} +unsafe impl<R: AsMut<Receiver<T>> + Sync, T: Send> Sync for Recv<R, T> {} + /// Max number of receivers. Reserve space to lock. 
const MAX_RECEIVERS: usize = usize::MAX >> 2; @@ -386,19 +425,15 @@ pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { pos: 0, rx_cnt: 1, closed: false, + waiters: LinkedList::new(), }), - wait_stack: AtomicPtr::new(ptr::null_mut()), num_tx: AtomicUsize::new(1), }); let rx = Receiver { shared: shared.clone(), next: 0, - wait: Arc::new(WaitNode { - queued: AtomicBool::new(false), - waker: AtomicWaker::new(), - next: UnsafeCell::new(ptr::null()), - }), + waiter: None, }; let tx = Sender { shared }; @@ -508,11 +543,7 @@ impl<T> Sender<T> { Receiver { shared, next, - wait: Arc::new(WaitNode { - queued: AtomicBool::new(false), - waker: AtomicWaker::new(), - next: UnsafeCell::new(ptr::null()), - }), + waiter: None, } } @@ -589,34 +620,31 @@ impl<T> Sender<T> { slot.val.with_mut(|ptr| unsafe { *ptr = value }); } - // Release the slot lock before the tail lock + // Release the slot lock before notifying the receivers. drop(slot); + tail.notify_rx(); + // Release the mutex. This must happen after the slot lock is released, // otherwise the writer lock bit could be cleared while another thread // is in the critical section. drop(tail); - // Notify waiting receivers - self.notify_rx(); - Ok(rem) } +} - fn notify_rx(&self) { - let mut curr = self.shared.wait_stack.swap(ptr::null_mut(), SeqCst) as *const WaitNode; - - while !curr.is_null() { - let waiter = unsafe { Arc::from_raw(curr) }; - - // Update `curr` before toggling `queued` and waking - curr = waiter.next.with(|ptr| unsafe { *ptr }); +impl Tail { + fn notify_rx(&mut self) { + while let Some(mut waiter) = self.waiters.pop_back() { + // Safety: `waiters` lock is still held. 
+ let waiter = unsafe { waiter.as_mut() }; - // Unset queued - waiter.queued.store(false, SeqCst); + assert!(waiter.queued); + waiter.queued = false; - // Wake - waiter.waker.wake(); + let waker = waiter.waker.take().unwrap(); + waker.wake(); } } } @@ -640,15 +668,21 @@ impl<T> Drop for Sender<T> { impl<T> Receiver<T> { /// Locks the next value if there is one. - fn recv_ref(&mut self) -> Result<RecvGuard<'_, T>, TryRecvError> { + fn recv_ref( + &mut self, + waiter: Option<(&UnsafeCell<Waiter>, &Waker)>, + ) -> Result<RecvGuard<'_, T>, TryRecvError> { let idx = (self.next & self.shared.mask as u64) as usize; // The slot holding the next value to read let mut slot = self.shared.buffer[idx].read().unwrap(); if slot.pos != self.next { - // The receiver has read all current values in the channel - if slot.pos.wrapping_add(self.shared.buffer.len() as u64) == self.next { + let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); + + // The receiver has read all current values in the channel and there + // is no waiter to register + if waiter.is_none() && next_pos == self.next { return Err(TryRecvError::Empty); } @@ -661,35 +695,83 @@ impl<T> Receiver<T> { // the slot lock. drop(slot); - let tail = self.shared.tail.lock().unwrap(); + let mut tail = self.shared.tail.lock().unwrap(); // Acquire slot lock again slot = self.shared.buffer[idx].read().unwrap(); - // `tail.pos` points to the slot that the **next** send writes to. If - // the channel is closed, the previous slot is the oldest value. - let mut adjust = 0; - if tail.closed { - adjust = 1 - } - let next = tail - .pos - .wrapping_sub(self.shared.buffer.len() as u64 + adjust); + // Make sure the position did not change. This could happen in the + // unlikely event that the buffer is wrapped between dropping the + // read lock and acquiring the tail lock. 
+ if slot.pos != self.next { + let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); + + if next_pos == self.next { + // Store the waker + if let Some((waiter, waker)) = waiter { + // Safety: called while locked. + unsafe { + // Only queue if not already queued + waiter.with_mut(|ptr| { + // If there is no waker **or** if the currently + // stored waker references a **different** task, + // track the tasks' waker to be notified on + // receipt of a new value. + match (*ptr).waker { + Some(ref w) if w.will_wake(waker) => {} + _ => { + (*ptr).waker = Some(waker.clone()); + } + } + + if !(*ptr).queued { + (*ptr).queued = true; + tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr)); + } + }); + } + } + + return Err(TryRecvError::Empty); + } - let missed = next.wrapping_sub(self.next); + // At this point, the receiver has lagged behind the sender by + // more than the channel capacity. The receiver will attempt to + // catch up by skipping dropped messages and setting the + // internal cursor to the **oldest** message stored by the + // channel. + // + // However, finding the oldest position is a bit more + // complicated than `tail-position - buffer-size`. When + // the channel is closed, the tail position is incremented to + // signal a new `None` message, but `None` is not stored in the + // channel itself (see issue #2425 for why). + // + // To account for this, if the channel is closed, the tail + // position is decremented by `buffer-size + 1`. 
+ let mut adjust = 0; + if tail.closed { + adjust = 1 + } + let next = tail + .pos + .wrapping_sub(self.shared.buffer.len() as u64 + adjust); - drop(tail); + let missed = next.wrapping_sub(self.next); - // The receiver is slow but no values have been missed - if missed == 0 { - self.next = self.next.wrapping_add(1); + drop(tail); - return Ok(RecvGuard { slot }); - } + // The receiver is slow but no values have been missed + if missed == 0 { + self.next = self.next.wrapping_add(1); - self.next = next; + return Ok(RecvGuard { slot }); + } + + self.next = next; - return Err(TryRecvError::Lagged(missed)); + return Err(TryRecvError::Lagged(missed)); + } } self.next = self.next.wrapping_add(1); @@ -746,22 +828,59 @@ where /// } /// ``` pub fn try_recv(&mut self) -> Result<T, TryRecvError> { - let guard = self.recv_ref()?; + let guard = self.recv_ref(None)?; guard.clone_value().ok_or(TryRecvError::Closed) } - #[doc(hidden)] // TODO: document + #[doc(hidden)] + #[deprecated(since = "0.2.21", note = "use async fn recv()")] pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { - if let Some(value) = ok_empty(self.try_recv())? { - return Poll::Ready(Ok(value)); + use Poll::{Pending, Ready}; + + // The borrow checker prohibits calling `self.poll_ref` while passing in + // a mutable ref to a field (as it should). To work around this, + // `waiter` is first *removed* from `self` then `poll_recv` is called. + // + // However, for safety, we must ensure that `waiter` is **not** dropped. + // It could be contained in the intrusive linked list. The `Receiver` + // drop implementation handles cleanup. + // + // The guard pattern is used to ensure that, on return, even due to + // panic, the waiter node is replaced on `self`. 
+ + struct Guard<'a, T> { + waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>, + receiver: &'a mut Receiver<T>, } - self.register_waker(cx.waker()); + impl<'a, T> Drop for Guard<'a, T> { + fn drop(&mut self) { + self.receiver.waiter = self.waiter.take(); + } + } - if let Some(value) = ok_empty(self.try_recv())? { - Poll::Ready(Ok(value)) - } else { - Poll::Pending + let waiter = self.waiter.take().or_else(|| { + Some(Box::pin(UnsafeCell::new(Waiter { + queued: false, + waker: None, + pointers: linked_list::Pointers::new(), + _p: PhantomPinned, + }))) + }); + + let guard = Guard { + waiter, + receiver: self, + }; + let res = guard + .receiver + .recv_ref(Some((&guard.waiter.as_ref().unwrap(), cx.waker()))); + + match res { + Ok(guard) => Ready(guard.clone_value().ok_or(RecvError::Closed)), + Err(TryRecvError::Closed) => Ready(Err(RecvError::Closed)), + Err(TryRecvError::Lagged(n)) => Ready(Err(RecvError::Lagged(n))), + Err(TryRecvError::Empty) => Pending, } } @@ -830,44 +949,14 @@ where /// assert_eq!(30, rx.recv().await.unwrap()); /// } pub async fn recv(&mut self) -> Result<T, RecvError> { - use crate::future::poll_fn; - - poll_fn(|cx| self.poll_recv(cx)).await - } - - fn register_waker(&self, cx: &Waker) { - self.wait.waker.register_by_ref(cx); - - if !self.wait.queued.load(SeqCst) { - // Set `queued` before queuing. - self.wait.queued.store(true, SeqCst); - - let mut curr = self.shared.wait_stack.load(SeqCst); - - // The ref count is decremented in `notify_rx` when all nodes are - // removed from the waiter stack. - let node = Arc::into_raw(self.wait.clone()) as *mut _; - - loop { - // Safety: `queued == false` means the caller has exclusive - // access to `self.wait.next`. 
- self.wait.next.with_mut(|ptr| unsafe { *ptr = curr }); - - let res = self - .shared - .wait_stack - .compare_exchange(curr, node, SeqCst, SeqCst); - - match res { - Ok(_) => return, - Err(actual) => curr = actual, - } - } - } + let fut = Recv::<_, T>::new(Borrow(self)); + fut.await } } #[cfg(feature = "stream")] +#[doc(hidden)] +#[deprecated(since = "0.2.21", note = "use `into_stream()`")] impl<T> crate::stream::Stream for Receiver<T> where T: Clone, @@ -878,6 +967,7 @@ where mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<T, RecvError>>> { + #[allow(deprecated)] self.poll_recv(cx).map(|v| match v { Ok(v) => Some(Ok(v)), lag @ Err(RecvError::Lagged(_)) => Some(lag), @@ -890,13 +980,30 @@ impl<T> Drop for Receiver<T> { fn drop(&mut self) { let mut tail = self.shared.tail.lock().unwrap(); + if let Some(waiter) = &self.waiter { + // safety: tail lock is held + let queued = waiter.with(|ptr| unsafe { (*ptr).queued }); + + if queued { + // Remove the node + // + // safety: tail lock is held and the wait node is verified to be in + // the list. 
+ unsafe { + waiter.with_mut(|ptr| { + tail.waiters.remove((&mut *ptr).into()); + }); + } + } + } + tail.rx_cnt -= 1; let until = tail.pos; drop(tail); while self.next != until { - match self.recv_ref() { + match self.recv_ref(None) { Ok(_) => {} // The channel is closed Err(TryRecvError::Closed) => break, @@ -909,18 +1016,170 @@ impl<T> Drop for Receiver<T> { } } -impl<T> Drop for Shared<T> { - fn drop(&mut self) { - // Clear the wait stack - let mut curr = self.wait_stack.with_mut(|ptr| *ptr as *const WaitNode); +impl<R, T> Recv<R, T> +where + R: AsMut<Receiver<T>>, +{ + fn new(receiver: R) -> Recv<R, T> { + Recv { + receiver, + waiter: UnsafeCell::new(Waiter { + queued: false, + waker: None, + pointers: linked_list::Pointers::new(), + _p: PhantomPinned, + }), + _p: std::marker::PhantomData, + } + } - while !curr.is_null() { - let waiter = unsafe { Arc::from_raw(curr) }; - curr = waiter.next.with(|ptr| unsafe { *ptr }); + /// A custom `project` implementation is used in place of `pin-project-lite` + /// as a custom drop implementation is needed. + fn project(self: Pin<&mut Self>) -> (&mut Receiver<T>, &UnsafeCell<Waiter>) { + unsafe { + // Safety: Receiver is Unpin + is_unpin::<&mut Receiver<T>>(); + + let me = self.get_unchecked_mut(); + (me.receiver.as_mut(), &me.waiter) } } } +impl<R, T> Future for Recv<R, T> +where + R: AsMut<Receiver<T>>, + T: Clone, +{ + type Output = Result<T, RecvError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { + let (receiver, waiter) = self.project(); + + let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) { + Ok(value) => value, + Err(TryRecvError::Empty) => return Poll::Pending, + Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))), + Err(TryRecvError::Closed) => return Poll::Ready(Err(RecvError::Closed)), + }; + + Poll::Ready(guard.clone_value().ok_or(RecvError::Closed)) + } +} + +cfg_stream! 
{ + use futures_core::Stream; + + impl<T: Clone> Receiver<T> { + /// Convert the receiver into a `Stream`. + /// + /// The conversion allows using `Receiver` with APIs that require stream + /// values. + /// + /// # Examples + /// + /// ``` + /// use tokio::stream::StreamExt; + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = broadcast::channel(128); + /// + /// tokio::spawn(async move { + /// for i in 0..10_i32 { + /// tx.send(i).unwrap(); + /// } + /// }); + /// + /// // Streams must be pinned to iterate. + /// tokio::pin! { + /// let stream = rx + /// .into_stream() + /// .filter(Result::is_ok) + /// .map(Result::unwrap) + /// .filter(|v| v % 2 == 0) + /// .map(|v| v + 1); + /// } + /// + /// while let Some(i) = stream.next().await { + /// println!("{}", i); + /// } + /// } + /// ``` + pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> { + Recv::new(Borrow(self)) + } + } + + impl<R, T: Clone> Stream for Recv<R, T> + where + R: AsMut<Receiver<T>>, + T: Clone, + { + type Item = Result<T, RecvError>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + let (receiver, waiter) = self.project(); + + let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) { + Ok(value) => value, + Err(TryRecvError::Empty) => return Poll::Pending, + Err(TryRecvError::Lagged(n)) => return Poll::Ready(Some(Err(RecvError::Lagged(n)))), + Err(TryRecvError::Closed) => return Poll::Ready(None), + }; + + Poll::Ready(guard.clone_value().map(Ok)) + } + } +} + +impl<R, T> Drop for Recv<R, T> +where + R: AsMut<Receiver<T>>, +{ + fn drop(&mut self) { + // Acquire the tail lock. This is required for safety before accessing + // the waiter node. 
+ let mut tail = self.receiver.as_mut().shared.tail.lock().unwrap(); + + // safety: tail lock is held + let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued }); + + if queued { + // Remove the node + // + // safety: tail lock is held and the wait node is verified to be in + // the list. + unsafe { + self.waiter.with_mut(|ptr| { + tail.waiters.remove((&mut *ptr).into()); + }); + } + } + } +} + +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + type Handle = NonNull<Waiter>; + type Target = Waiter; + + fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> { + *handle + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { + ptr + } + + unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + NonNull::from(&mut target.as_mut().pointers) + } +} + impl<T> fmt::Debug for Sender<T> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { write!(fmt, "broadcast::Sender") @@ -952,15 +1211,6 @@ impl<'a, T> Drop for RecvGuard<'a, T> { } } -fn ok_empty<T>(res: Result<T, TryRecvError>) -> Result<Option<T>, RecvError> { - match res { - Ok(value) => Ok(Some(value)), - Err(TryRecvError::Empty) => Ok(None), - Err(TryRecvError::Lagged(n)) => Err(RecvError::Lagged(n)), - Err(TryRecvError::Closed) => Err(RecvError::Closed), - } -} - impl fmt::Display for RecvError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -983,3 +1233,5 @@ impl fmt::Display for TryRecvError { } impl std::error::Error for TryRecvError {} + +fn is_unpin<T: Unpin>() {} |