diff options
author | Carl Lerche <me@carllerche.com> | 2020-05-06 07:37:44 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-06 07:37:44 -0700 |
commit | cc8a6625982b5fc0694d05b4e9fb7d6a592702a1 (patch) | |
tree | 5f946da38fb9931f08998f0eb472828757ad3a8e /tokio/src/sync/broadcast.rs | |
parent | 264ae3bdb22004609de45b67e2890081bb47e5b2 (diff) |
sync: simplify the broadcast channel (#2467)
Replace an ad hoc read/write lock with RwLock. Use
The parking_lot RwLock when possible.
Diffstat (limited to 'tokio/src/sync/broadcast.rs')
-rw-r--r-- | tokio/src/sync/broadcast.rs | 215 |
1 files changed, 57 insertions, 158 deletions
diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index abc4974a..9873dcb7 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -110,11 +110,10 @@ use crate::loom::cell::UnsafeCell; use crate::loom::future::AtomicWaker; -use crate::loom::sync::atomic::{spin_loop_hint, AtomicBool, AtomicPtr, AtomicUsize}; -use crate::loom::sync::{Arc, Condvar, Mutex}; +use crate::loom::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize}; +use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; use std::fmt; -use std::mem; use std::ptr; use std::sync::atomic::Ordering::SeqCst; use std::task::{Context, Poll, Waker}; @@ -247,7 +246,7 @@ pub enum TryRecvError { /// Data shared between senders and receivers struct Shared<T> { /// slots in the channel - buffer: Box<[Slot<T>]>, + buffer: Box<[RwLock<Slot<T>>]>, /// Mask a position -> index mask: usize, @@ -255,9 +254,6 @@ struct Shared<T> { /// Tail of the queue tail: Mutex<Tail>, - /// Notifies a sender that the slot is unlocked - condvar: Condvar, - /// Stack of pending waiters wait_stack: AtomicPtr<WaitNode>, @@ -282,23 +278,21 @@ struct Slot<T> { /// Remaining number of receivers that are expected to see this value. /// /// When this goes to zero, the value is released. + /// + /// An atomic is used as it is mutated concurrently with the slot read lock + /// acquired. rem: AtomicUsize, - /// Used to lock the `write` field. - lock: AtomicUsize, - - /// The value being broadcast - /// - /// Synchronized by `state` - write: Write<T>, -} + /// Uniquely identifies the `send` stored in the slot + pos: u64, -/// A write in the buffer -struct Write<T> { - /// Uniquely identifies this write - pos: UnsafeCell<u64>, + /// True signals the channel is closed. + closed: bool, - /// The written value + /// The value being broadcast. + /// + /// The value is set by `send` when the write lock is held. When a reader + /// drops, `rem` is decremented. When it hits zero, the value is dropped. val: UnsafeCell<Option<T>>, } @@ -316,16 +310,11 @@ struct WaitNode { } struct RecvGuard<'a, T> { - slot: &'a Slot<T>, - tail: &'a Mutex<Tail>, - condvar: &'a Condvar, + slot: RwLockReadGuard<'a, Slot<T>>, } /// Max number of receivers. Reserve space to lock. const MAX_RECEIVERS: usize = usize::MAX >> 2; -const CLOSED: usize = 1; -const WRITER: usize = 2; -const READER: usize = 4; /// Create a bounded, multi-producer, multi-consumer channel where each sent /// value is broadcasted to all active receivers. @@ -382,14 +371,12 @@ pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { let mut buffer = Vec::with_capacity(capacity); for i in 0..capacity { - buffer.push(Slot { + buffer.push(RwLock::new(Slot { rem: AtomicUsize::new(0), - lock: AtomicUsize::new(0), - write: Write { - pos: UnsafeCell::new((i as u64).wrapping_sub(capacity as u64)), - val: UnsafeCell::new(None), - }, - }); + pos: (i as u64).wrapping_sub(capacity as u64), + closed: false, + val: UnsafeCell::new(None), + })); } let shared = Arc::new(Shared { @@ -400,7 +387,6 @@ pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { rx_cnt: 1, closed: false, }), - condvar: Condvar::new(), wait_stack: AtomicPtr::new(ptr::null_mut()), num_tx: AtomicUsize::new(1), }); @@ -587,46 +573,25 @@ impl<T> Sender<T> { tail.pos = tail.pos.wrapping_add(1); // Get the slot - let slot = &self.shared.buffer[idx]; - - // Acquire the write lock - let mut prev = slot.lock.fetch_or(WRITER, SeqCst); + let mut slot = self.shared.buffer[idx].write().unwrap(); - while prev & !WRITER != 0 { - // Concurrent readers, we must go to sleep - tail = self.shared.condvar.wait(tail).unwrap(); - - prev = slot.lock.load(SeqCst); - - if prev & WRITER == 0 { - // The writer lock bit was cleared while this thread was - // sleeping. This can only happen if a newer write happened on - // this slot by another thread. Bail early as an optimization, - // there is nothing left to do. - return Ok(rem); - } - } - - if tail.pos.wrapping_sub(pos) > self.shared.buffer.len() as u64 { - // There is a newer pending write to the same slot. - return Ok(rem); - } - - // Slot lock acquired - slot.write.pos.with_mut(|ptr| unsafe { *ptr = pos }); + // Track the position + slot.pos = pos; // Set remaining receivers - slot.rem.store(rem, SeqCst); + slot.rem.with_mut(|v| *v = rem); // Set the closed bit if the value is `None`; otherwise write the value if value.is_none() { tail.closed = true; - slot.lock.store(CLOSED, SeqCst); + slot.closed = true; } else { - slot.write.val.with_mut(|ptr| unsafe { *ptr = value }); - slot.lock.store(0, SeqCst); + slot.val.with_mut(|ptr| unsafe { *ptr = value }); } + // Release the slot lock before the tail lock + drop(slot); + // Release the mutex. This must happen after the slot lock is released, // otherwise the writer lock bit could be cleared while another thread // is in the critical section. @@ -675,42 +640,32 @@ impl<T> Drop for Sender<T> { impl<T> Receiver<T> { /// Locks the next value if there is one. - /// - /// The caller is responsible for unlocking - fn recv_ref(&mut self, spin: bool) -> Result<RecvGuard<'_, T>, TryRecvError> { + fn recv_ref(&mut self) -> Result<RecvGuard<'_, T>, TryRecvError> { let idx = (self.next & self.shared.mask as u64) as usize; // The slot holding the next value to read - let slot = &self.shared.buffer[idx]; - - // Lock the slot - if !slot.try_rx_lock() { - if spin { - while !slot.try_rx_lock() { - spin_loop_hint(); - } - } else { - return Err(TryRecvError::Empty); - } - } - - let guard = RecvGuard { - slot, - tail: &self.shared.tail, - condvar: &self.shared.condvar, - }; - - if guard.pos() != self.next { - let pos = guard.pos(); + let mut slot = self.shared.buffer[idx].read().unwrap(); + if slot.pos != self.next { // The receiver has read all current values in the channel - if pos.wrapping_add(self.shared.buffer.len() as u64) == self.next { - guard.drop_no_rem_dec(); + if slot.pos.wrapping_add(self.shared.buffer.len() as u64) == self.next { return Err(TryRecvError::Empty); } + // Release the `slot` lock before attempting to acquire the `tail` + // lock. This is required because `send2` acquires the tail lock + // first followed by the slot lock. Acquiring the locks in reverse + // order here would result in a potential deadlock: `recv_ref` + // acquires the `slot` lock and attempts to acquire the `tail` lock + // while `send2` acquired the `tail` lock and attempts to acquire + // the slot lock. + drop(slot); + let tail = self.shared.tail.lock().unwrap(); + // Acquire slot lock again + slot = self.shared.buffer[idx].read().unwrap(); + // `tail.pos` points to the slot that the **next** send writes to. If // the channel is closed, the previous slot is the oldest value. let mut adjust = 0; @@ -728,10 +683,10 @@ impl<T> Receiver<T> { // The receiver is slow but no values have been missed if missed == 0 { self.next = self.next.wrapping_add(1); - return Ok(guard); + + return Ok(RecvGuard { slot }); } - guard.drop_no_rem_dec(); self.next = next; return Err(TryRecvError::Lagged(missed)); @@ -739,17 +694,11 @@ impl<T> Receiver<T> { self.next = self.next.wrapping_add(1); - // If the `CLOSED` bit it set on the slot, the channel is closed - // - // `try_rx_lock` could check for this and bail early. If it's return - // value was changed to represent the state of the lock, it could - // match on being closed, empty, or available for reading. - if slot.lock.load(SeqCst) & CLOSED == CLOSED { - guard.drop_no_rem_dec(); + if slot.closed { return Err(TryRecvError::Closed); } - Ok(guard) + Ok(RecvGuard { slot }) } } @@ -797,7 +746,7 @@ where /// } /// ``` pub fn try_recv(&mut self) -> Result<T, TryRecvError> { - let guard = self.recv_ref(false)?; + let guard = self.recv_ref()?; guard.clone_value().ok_or(TryRecvError::Closed) } @@ -947,7 +896,7 @@ impl<T> Drop for Receiver<T> { drop(tail); while self.next != until { - match self.recv_ref(true) { + match self.recv_ref() { Ok(_) => {} // The channel is closed Err(TryRecvError::Closed) => break, @@ -984,72 +933,22 @@ impl<T> fmt::Debug for Receiver<T> { } } -impl<T> Slot<T> { - /// Tries to lock the slot for a receiver. If `false`, then a sender holds the - /// lock and the calling task will be notified once the sender has released - /// the lock. - fn try_rx_lock(&self) -> bool { - let mut curr = self.lock.load(SeqCst); - - loop { - if curr & WRITER == WRITER { - // Locked by sender - return false; - } - - // Only increment (by `READER`) if the `WRITER` bit is not set. - let res = self - .lock - .compare_exchange(curr, curr + READER, SeqCst, SeqCst); - - match res { - Ok(_) => return true, - Err(actual) => curr = actual, - } - } - } - - fn rx_unlock(&self, tail: &Mutex<Tail>, condvar: &Condvar, rem_dec: bool) { - if rem_dec { - // Decrement the remaining counter - if 1 == self.rem.fetch_sub(1, SeqCst) { - // Last receiver, drop the value - self.write.val.with_mut(|ptr| unsafe { *ptr = None }); - } - } - - if WRITER == self.lock.fetch_sub(READER, SeqCst) - READER { - // First acquire the lock to make sure our sender is waiting on the - // condition variable, otherwise the notification could be lost. - mem::drop(tail.lock().unwrap()); - // Wake up senders - condvar.notify_all(); - } - } -} - impl<'a, T> RecvGuard<'a, T> { - fn pos(&self) -> u64 { - self.slot.write.pos.with(|ptr| unsafe { *ptr }) - } - fn clone_value(&self) -> Option<T> where T: Clone, { - self.slot.write.val.with(|ptr| unsafe { (*ptr).clone() }) - } - - fn drop_no_rem_dec(self) { - self.slot.rx_unlock(self.tail, self.condvar, false); - - mem::forget(self); + self.slot.val.with(|ptr| unsafe { (*ptr).clone() }) } } impl<'a, T> Drop for RecvGuard<'a, T> { fn drop(&mut self) { - self.slot.rx_unlock(self.tail, self.condvar, true) + // Decrement the remaining counter + if 1 == self.slot.rem.fetch_sub(1, SeqCst) { + // Safety: Last receiver, drop the value + self.slot.val.with_mut(|ptr| unsafe { *ptr = None }); + } } } |