Diffstat (limited to 'tokio/src/sync/broadcast.rs')
 tokio/src/sync/broadcast.rs | 90 +++++++++++++++++++++++++++++-----------
 1 file changed, 65 insertions(+), 25 deletions(-)
diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs
index 05a58070..abc4974a 100644
--- a/tokio/src/sync/broadcast.rs
+++ b/tokio/src/sync/broadcast.rs
@@ -272,6 +272,9 @@ struct Tail {
     /// Number of active receivers
     rx_cnt: usize,
+
+    /// True if the channel is closed
+    closed: bool,
 }
 
 /// Slot in the buffer
@@ -319,7 +322,10 @@ struct RecvGuard<'a, T> {
 }
 
 /// Max number of receivers. Reserve space to lock.
-const MAX_RECEIVERS: usize = usize::MAX >> 1;
+const MAX_RECEIVERS: usize = usize::MAX >> 2;
+const CLOSED: usize = 1;
+const WRITER: usize = 2;
+const READER: usize = 4;
 
 /// Create a bounded, multi-producer, multi-consumer channel where each sent
 /// value is broadcasted to all active receivers.
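
Note: the slot's `lock` word now packs three fields: bit 0 is the `CLOSED`
marker, bit 1 is the `WRITER` lock, and everything from bit 2 up counts
active readers in steps of `READER`. `MAX_RECEIVERS` shrinks to
`usize::MAX >> 2` so the reader count can never spill into the two flag
bits. A minimal, self-contained sketch of that encoding (the `readers`
helper is mine, not part of the patch):

    const CLOSED: usize = 1; // bit 0: channel closed
    const WRITER: usize = 2; // bit 1: a sender holds the slot
    const READER: usize = 4; // bits 2+: reader count, in steps of 4

    // Hypothetical helper: recover the reader count from a lock word.
    fn readers(lock: usize) -> usize {
        lock / READER
    }

    fn main() {
        let lock = WRITER | 3 * READER; // writer present, three readers
        assert_eq!(readers(lock), 3);
        assert_eq!(lock & WRITER, WRITER);
        assert_eq!(lock & CLOSED, 0);
    }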
@@ -389,7 +395,11 @@ pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
     let shared = Arc::new(Shared {
         buffer: buffer.into_boxed_slice(),
         mask: capacity - 1,
-        tail: Mutex::new(Tail { pos: 0, rx_cnt: 1 }),
+        tail: Mutex::new(Tail {
+            pos: 0,
+            rx_cnt: 1,
+            closed: false,
+        }),
         condvar: Condvar::new(),
         wait_stack: AtomicPtr::new(ptr::null_mut()),
         num_tx: AtomicUsize::new(1),
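
Note: the new `closed` flag is what ultimately lets receivers observe
`TryRecvError::Closed` once every sender is gone. Exercised through the
public API, the behavior this patch enables looks roughly like the
following (a sketch against tokio's documented `broadcast` interface,
not code from the patch):

    use tokio::sync::broadcast;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = broadcast::channel(16);

        tx.send(10).unwrap();
        drop(tx); // dropping the last sender closes the channel

        // The buffered value is still delivered...
        assert_eq!(rx.recv().await.unwrap(), 10);
        // ...and only then does the receiver observe the close.
        assert!(rx.recv().await.is_err());
    }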
@@ -580,15 +590,15 @@ impl<T> Sender<T> {
         let slot = &self.shared.buffer[idx];
 
         // Acquire the write lock
-        let mut prev = slot.lock.fetch_or(1, SeqCst);
+        let mut prev = slot.lock.fetch_or(WRITER, SeqCst);
 
-        while prev & !1 != 0 {
+        while prev & !WRITER != 0 {
             // Concurrent readers, we must go to sleep
             tail = self.shared.condvar.wait(tail).unwrap();
 
             prev = slot.lock.load(SeqCst);
 
-            if prev & 1 == 0 {
+            if prev & WRITER == 0 {
                 // The writer lock bit was cleared while this thread was
                 // sleeping. This can only happen if a newer write happened on
                 // this slot by another thread. Bail early as an optimization,
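
Note: the hunk above only renames the magic `1` to `WRITER`; the handshake
itself is unchanged. Reduced to a standalone sketch (structure and names
are mine, simplified from the patch): the sender sets the writer bit,
then sleeps on the condvar, which atomically releases and reacquires the
tail mutex, until all readers have drained:

    use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
    use std::sync::{Condvar, Mutex, MutexGuard};

    const WRITER: usize = 2;

    // Set the writer bit, then wait until no reader bits remain.
    fn write_lock<'a, T>(
        lock: &AtomicUsize,
        condvar: &Condvar,
        mut tail: MutexGuard<'a, T>,
    ) -> MutexGuard<'a, T> {
        let mut prev = lock.fetch_or(WRITER, SeqCst);
        while prev & !WRITER != 0 {
            tail = condvar.wait(tail).unwrap();
            prev = lock.load(SeqCst);
            if prev & WRITER == 0 {
                // A newer write already claimed and released this slot.
                break;
            }
        }
        tail
    }

    fn main() {
        let (lock, condvar, tail) = (AtomicUsize::new(0), Condvar::new(), Mutex::new(()));
        let guard = write_lock(&lock, &condvar, tail.lock().unwrap());
        assert_eq!(lock.load(SeqCst), WRITER); // uncontended: no sleeping
        drop(guard);
    }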
@@ -604,13 +614,18 @@ impl<T> Sender<T> {
         // Slot lock acquired
         slot.write.pos.with_mut(|ptr| unsafe { *ptr = pos });
-        slot.write.val.with_mut(|ptr| unsafe { *ptr = value });
 
         // Set remaining receivers
         slot.rem.store(rem, SeqCst);
 
-        // Release the slot lock
-        slot.lock.store(0, SeqCst);
+        // Set the closed bit if the value is `None`; otherwise write the value
+        if value.is_none() {
+            tail.closed = true;
+            slot.lock.store(CLOSED, SeqCst);
+        } else {
+            slot.write.val.with_mut(|ptr| unsafe { *ptr = value });
+            slot.lock.store(0, SeqCst);
+        }
 
         // Release the mutex. This must happen after the slot lock is released,
         // otherwise the writer lock bit could be cleared while another thread
@@ -688,28 +703,52 @@ impl<T> Receiver<T> {
         if guard.pos() != self.next {
             let pos = guard.pos();
 
-            guard.drop_no_rem_dec();
-
+            // The receiver has read all current values in the channel
             if pos.wrapping_add(self.shared.buffer.len() as u64) == self.next {
+                guard.drop_no_rem_dec();
                 return Err(TryRecvError::Empty);
-            } else {
-                let tail = self.shared.tail.lock().unwrap();
+            }
 
-                // `tail.pos` points to the slot the **next** send writes to.
-                // Because a receiver is lagging, this slot also holds the
-                // oldest value. To make the positions match, we subtract the
-                // capacity.
-                let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64);
-                let missed = next.wrapping_sub(self.next);
+            let tail = self.shared.tail.lock().unwrap();
 
-                self.next = next;
+            // `tail.pos` points to the slot that the **next** send writes to. If
+            // the channel is closed, the previous slot is the oldest value.
+            let mut adjust = 0;
+            if tail.closed {
+                adjust = 1
+            }
+            let next = tail
+                .pos
+                .wrapping_sub(self.shared.buffer.len() as u64 + adjust);
 
-                return Err(TryRecvError::Lagged(missed));
+            let missed = next.wrapping_sub(self.next);
+
+            drop(tail);
+
+            // The receiver is slow but no values have been missed
+            if missed == 0 {
+                self.next = self.next.wrapping_add(1);
+                return Ok(guard);
             }
+
+            guard.drop_no_rem_dec();
+            self.next = next;
+
+            return Err(TryRecvError::Lagged(missed));
         }
 
         self.next = self.next.wrapping_add(1);
 
+        // If the `CLOSED` bit is set on the slot, the channel is closed.
+        //
+        // `try_rx_lock` could check for this and bail early. If its return
+        // value were changed to represent the state of the lock, it could
+        // match on being closed, empty, or available for reading.
+        if slot.lock.load(SeqCst) & CLOSED == CLOSED {
+            guard.drop_no_rem_dec();
+            return Err(TryRecvError::Closed);
+        }
+
         Ok(guard)
     }
 }
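
Note: the lag arithmetic is easiest to verify with concrete numbers. With
`buffer.len() == 4` and `tail.pos == 10`, the oldest retained value sits
at position 10 - 4 = 6, so a receiver whose `next` is 3 has missed
6 - 3 = 3 values. When the channel is closed, the slot just before
`tail.pos` holds the close marker rather than a value, which is what
`adjust` accounts for. A quick standalone check (numbers and the `oldest`
helper are mine):

    // Position of the oldest value still held in the buffer.
    fn oldest(tail_pos: u64, capacity: u64, closed: bool) -> u64 {
        // When closed, the slot before `tail_pos` holds the close
        // marker, not a value, so the oldest value is one earlier.
        let adjust = if closed { 1 } else { 0 };
        tail_pos.wrapping_sub(capacity + adjust)
    }

    fn main() {
        // Open channel: capacity 4, next write goes to position 10.
        assert_eq!(oldest(10, 4, false), 6);
        // A receiver at position 3 has missed 6 - 3 = 3 values.
        assert_eq!(oldest(10, 4, false).wrapping_sub(3), 3);
        // Closed channel: marker at 9, values occupy 5..=8.
        assert_eq!(oldest(10, 4, true), 5);
    }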
@@ -909,7 +948,6 @@ impl<T> Drop for Receiver<T> {
         while self.next != until {
             match self.recv_ref(true) {
-                // Ignore the value
                 Ok(_) => {}
 
                 // The channel is closed
                 Err(TryRecvError::Closed) => break,
@@ -954,13 +992,15 @@ impl<T> Slot<T> {
         let mut curr = self.lock.load(SeqCst);
 
         loop {
-            if curr & 1 == 1 {
+            if curr & WRITER == WRITER {
                 // Locked by sender
                 return false;
             }
 
-            // Only increment (by 2) if the LSB "lock" bit is not set.
-            let res = self.lock.compare_exchange(curr, curr + 2, SeqCst, SeqCst);
+            // Only increment (by `READER`) if the `WRITER` bit is not set.
+            let res = self
+                .lock
+                .compare_exchange(curr, curr + READER, SeqCst, SeqCst);
 
             match res {
                 Ok(_) => return true,
@@ -978,7 +1018,7 @@ impl<T> Slot<T> {
             }
         }
 
-        if 1 == self.lock.fetch_sub(2, SeqCst) - 2 {
+        if WRITER == self.lock.fetch_sub(READER, SeqCst) - READER {
             // First acquire the lock to make sure our sender is waiting on the
             // condition variable, otherwise the notification could be lost.
             mem::drop(tail.lock().unwrap());
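
Note: `fetch_sub` returns the value *before* the subtraction, so
`self.lock.fetch_sub(READER, SeqCst) - READER` is the lock word after this
reader leaves; it equals exactly `WRITER` only when this was the last
reader and a sender is parked on the condvar. Locking and immediately
dropping the tail mutex first guarantees the sender has already entered
`Condvar::wait` (which releases that same mutex), so the notification
cannot be lost. The check with concrete numbers (a standalone sketch, not
patch code):

    use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};

    const WRITER: usize = 2;
    const READER: usize = 4;

    fn main() {
        // A parked writer plus two active readers: 2 | (2 * 4) = 10.
        let lock = AtomicUsize::new(WRITER | 2 * READER);

        // First reader leaves: word becomes WRITER + READER, no wakeup yet.
        assert_ne!(lock.fetch_sub(READER, SeqCst) - READER, WRITER);

        // Last reader leaves: word becomes exactly WRITER, wake the sender.
        assert_eq!(lock.fetch_sub(READER, SeqCst) - READER, WRITER);
    }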