diff options
Diffstat (limited to 'tokio/src/sync/broadcast.rs')
-rw-r--r-- | tokio/src/sync/broadcast.rs | 90 |
1 files changed, 65 insertions, 25 deletions
diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 05a58070..abc4974a 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -272,6 +272,9 @@ struct Tail { /// Number of active receivers rx_cnt: usize, + + /// True if the channel is closed + closed: bool, } /// Slot in the buffer @@ -319,7 +322,10 @@ struct RecvGuard<'a, T> { } /// Max number of receivers. Reserve space to lock. -const MAX_RECEIVERS: usize = usize::MAX >> 1; +const MAX_RECEIVERS: usize = usize::MAX >> 2; +const CLOSED: usize = 1; +const WRITER: usize = 2; +const READER: usize = 4; /// Create a bounded, multi-producer, multi-consumer channel where each sent /// value is broadcasted to all active receivers. @@ -389,7 +395,11 @@ pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { let shared = Arc::new(Shared { buffer: buffer.into_boxed_slice(), mask: capacity - 1, - tail: Mutex::new(Tail { pos: 0, rx_cnt: 1 }), + tail: Mutex::new(Tail { + pos: 0, + rx_cnt: 1, + closed: false, + }), condvar: Condvar::new(), wait_stack: AtomicPtr::new(ptr::null_mut()), num_tx: AtomicUsize::new(1), @@ -580,15 +590,15 @@ impl<T> Sender<T> { let slot = &self.shared.buffer[idx]; // Acquire the write lock - let mut prev = slot.lock.fetch_or(1, SeqCst); + let mut prev = slot.lock.fetch_or(WRITER, SeqCst); - while prev & !1 != 0 { + while prev & !WRITER != 0 { // Concurrent readers, we must go to sleep tail = self.shared.condvar.wait(tail).unwrap(); prev = slot.lock.load(SeqCst); - if prev & 1 == 0 { + if prev & WRITER == 0 { // The writer lock bit was cleared while this thread was // sleeping. This can only happen if a newer write happened on // this slot by another thread. Bail early as an optimization, @@ -604,13 +614,18 @@ impl<T> Sender<T> { // Slot lock acquired slot.write.pos.with_mut(|ptr| unsafe { *ptr = pos }); - slot.write.val.with_mut(|ptr| unsafe { *ptr = value }); // Set remaining receivers slot.rem.store(rem, SeqCst); - // Release the slot lock - slot.lock.store(0, SeqCst); + // Set the closed bit if the value is `None`; otherwise write the value + if value.is_none() { + tail.closed = true; + slot.lock.store(CLOSED, SeqCst); + } else { + slot.write.val.with_mut(|ptr| unsafe { *ptr = value }); + slot.lock.store(0, SeqCst); + } // Release the mutex. This must happen after the slot lock is released, // otherwise the writer lock bit could be cleared while another thread @@ -688,28 +703,52 @@ impl<T> Receiver<T> { if guard.pos() != self.next { let pos = guard.pos(); - guard.drop_no_rem_dec(); - + // The receiver has read all current values in the channel if pos.wrapping_add(self.shared.buffer.len() as u64) == self.next { + guard.drop_no_rem_dec(); return Err(TryRecvError::Empty); - } else { - let tail = self.shared.tail.lock().unwrap(); + } - // `tail.pos` points to the slot the **next** send writes to. - // Because a receiver is lagging, this slot also holds the - // oldest value. To make the positions match, we subtract the - // capacity. - let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64); - let missed = next.wrapping_sub(self.next); + let tail = self.shared.tail.lock().unwrap(); - self.next = next; + // `tail.pos` points to the slot that the **next** send writes to. If + // the channel is closed, the previous slot is the oldest value. + let mut adjust = 0; + if tail.closed { + adjust = 1 + } + let next = tail + .pos + .wrapping_sub(self.shared.buffer.len() as u64 + adjust); - return Err(TryRecvError::Lagged(missed)); + let missed = next.wrapping_sub(self.next); + + drop(tail); + + // The receiver is slow but no values have been missed + if missed == 0 { + self.next = self.next.wrapping_add(1); + return Ok(guard); } + + guard.drop_no_rem_dec(); + self.next = next; + + return Err(TryRecvError::Lagged(missed)); } self.next = self.next.wrapping_add(1); + // If the `CLOSED` bit it set on the slot, the channel is closed + // + // `try_rx_lock` could check for this and bail early. If it's return + // value was changed to represent the state of the lock, it could + // match on being closed, empty, or available for reading. + if slot.lock.load(SeqCst) & CLOSED == CLOSED { + guard.drop_no_rem_dec(); + return Err(TryRecvError::Closed); + } + Ok(guard) } } @@ -909,7 +948,6 @@ impl<T> Drop for Receiver<T> { while self.next != until { match self.recv_ref(true) { - // Ignore the value Ok(_) => {} // The channel is closed Err(TryRecvError::Closed) => break, @@ -954,13 +992,15 @@ impl<T> Slot<T> { let mut curr = self.lock.load(SeqCst); loop { - if curr & 1 == 1 { + if curr & WRITER == WRITER { // Locked by sender return false; } - // Only increment (by 2) if the LSB "lock" bit is not set. - let res = self.lock.compare_exchange(curr, curr + 2, SeqCst, SeqCst); + // Only increment (by `READER`) if the `WRITER` bit is not set. + let res = self + .lock + .compare_exchange(curr, curr + READER, SeqCst, SeqCst); match res { Ok(_) => return true, @@ -978,7 +1018,7 @@ impl<T> Slot<T> { } } - if 1 == self.lock.fetch_sub(2, SeqCst) - 2 { + if WRITER == self.lock.fetch_sub(READER, SeqCst) - READER { // First acquire the lock to make sure our sender is waiting on the // condition variable, otherwise the notification could be lost. mem::drop(tail.lock().unwrap()); |