author    kalcutter <31195032+kalcutter@users.noreply.github.com>  2020-01-22 22:59:05 +0100
committer Carl Lerche <me@carllerche.com>                          2020-01-22 13:59:05 -0800
commit    f9ea576ccae5beffeaa2f2c48c2c0d2f9449673b (patch)
tree      f726788538b57d83f5e223648d7d21622ec86c41
parent    7f580071f3e5d475db200d2101ff35be0b4f6efe (diff)
sync: fix broadcast bugs (#2135)
Make sure the tail mutex is acquired when `condvar` is notified; otherwise the wakeup may be lost and the sender left waiting forever. Use `notify_all()` instead of `notify_one()` to ensure that the correct sender is woken. Finally, only take the lock and notify once the last reader has released the slot. Additionally, `send()` could panic when the slot had another pending send; it now detects this case and returns early.
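To see the lost wakeup this fixes, consider a minimal standalone sketch (not tokio's code; all names here are invented for illustration). The flag being waited on lives in an atomic outside the mutex, just as `Slot::lock` lives outside the tail mutex, so the notifier must still synchronize on the mutex before notifying:

```rust
use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

// Sketch: `done` stands in for `Slot::lock`, `mutex` for the tail
// mutex, and `condvar` for the shared condvar.
struct Shared {
    done: AtomicBool,
    mutex: Mutex<()>,
    condvar: Condvar,
}

fn main() {
    let shared = Arc::new(Shared {
        done: AtomicBool::new(false),
        mutex: Mutex::new(()),
        condvar: Condvar::new(),
    });

    // Plays the sender: sleep until the flag is set.
    let waiter = {
        let shared = Arc::clone(&shared);
        thread::spawn(move || {
            let mut guard = shared.mutex.lock().unwrap();
            while !shared.done.load(SeqCst) {
                // A notification fired between the `load` above and this
                // `wait()` is lost if the notifier never takes the mutex,
                // because the thread is not parked yet.
                guard = shared.condvar.wait(guard).unwrap();
            }
        })
    };

    // Plays the reader in `rx_unlock`: set the flag, then acquire and
    // release the mutex *before* notifying. That forces the waiter to be
    // either before its `load` (it will see `true`) or already parked in
    // `wait()` (it will receive the notification).
    shared.done.store(true, SeqCst);
    drop(shared.mutex.lock().unwrap());
    shared.condvar.notify_all();

    waiter.join().unwrap();
}
```

The switch to `notify_all()` matters for a similar reason: with several senders parked on different slots, `notify_one()` may wake a sender whose slot is still held by readers; it simply re-checks and sleeps again without passing the notification on, so the sender that could make progress never wakes.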
-rw-r--r-- tokio/src/sync/broadcast.rs            49
-rw-r--r-- tokio/src/sync/tests/loom_broadcast.rs 40
2 files changed, 73 insertions, 16 deletions
diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs
index e020a3c3..35854811 100644
--- a/tokio/src/sync/broadcast.rs
+++ b/tokio/src/sync/broadcast.rs
@@ -214,9 +214,8 @@ pub enum RecvError {
/// be sent.
Closed,
- /// The receiver lagged too far behind and has been forcibly disconnected.
- /// Attempting to receive again will return the oldest message still
- /// retained by the channel.
+ /// The receiver lagged too far behind. Attempting to receive again will
+ /// return the oldest message still retained by the channel.
///
/// Includes the number of skipped messages.
Lagged(u64),
@@ -274,9 +273,9 @@ struct Tail {
rx_cnt: usize,
}
-/// Node in the linked list
+/// Slot in the buffer
struct Slot<T> {
- /// Remaining numer of senders that are expected to see this value.
+ /// Remaining number of receivers that are expected to see this value.
///
/// When this goes to zero, the value is released.
rem: AtomicUsize,
@@ -314,6 +313,7 @@ struct WaitNode {
struct RecvGuard<'a, T> {
slot: &'a Slot<T>,
+ tail: &'a Mutex<Tail>,
condvar: &'a Condvar,
}
@@ -579,17 +579,27 @@ impl<T> Sender<T> {
let slot = &self.shared.buffer[idx];
// Acquire the write lock
- let mut prev = slot.lock.fetch_add(1, SeqCst);
+ let mut prev = slot.lock.fetch_or(1, SeqCst);
while prev & !1 != 0 {
// Concurrent readers, we must go to sleep
tail = self.shared.condvar.wait(tail).unwrap();
prev = slot.lock.load(SeqCst);
+
+ if prev & 1 == 0 {
+ // The writer lock bit was cleared while this thread was
+ // sleeping. This can only happen if a newer write happened on
+ // this slot by another thread. Bail early as an optimization,
+ // there is nothing left to do.
+ return Ok(rem);
+ }
}
- // Release the mutex
- drop(tail);
+ if tail.pos.wrapping_sub(pos) > self.shared.buffer.len() as u64 {
+ // There is a newer pending write to the same slot.
+ return Ok(rem);
+ }
// Slot lock acquired
slot.write.pos.with_mut(|ptr| unsafe { *ptr = pos });
@@ -601,6 +611,11 @@ impl<T> Sender<T> {
// Release the slot lock
slot.lock.store(0, SeqCst);
+ // Release the mutex. This must happen after the slot lock is released,
+ // otherwise the writer lock bit could be cleared while another thread
+ // is in the critical section.
+ drop(tail);
+
// Notify waiting receivers
self.notify_rx();
@@ -665,6 +680,7 @@ impl<T> Receiver<T> {
let guard = RecvGuard {
slot,
+ tail: &self.shared.tail,
condvar: &self.shared.condvar,
};
@@ -952,7 +968,7 @@ impl<T> Slot<T> {
}
}
- fn rx_unlock(&self, condvar: &Condvar, rem_dec: bool) {
+ fn rx_unlock(&self, tail: &Mutex<Tail>, condvar: &Condvar, rem_dec: bool) {
if rem_dec {
// Decrement the remaining counter
if 1 == self.rem.fetch_sub(1, SeqCst) {
@@ -961,11 +977,12 @@ impl<T> Slot<T> {
}
}
- let prev = self.lock.fetch_sub(2, SeqCst);
-
- if prev & 1 == 1 {
- // Sender waiting for lock
- condvar.notify_one();
+ if 1 == self.lock.fetch_sub(2, SeqCst) - 2 {
+ // First acquire the lock to make sure our sender is waiting on the
+ // condition variable, otherwise the notification could be lost.
+ let _ = tail.lock().unwrap();
+ // Wake up senders
+ condvar.notify_all();
}
}
}
@@ -985,7 +1002,7 @@ impl<'a, T> RecvGuard<'a, T> {
fn drop_no_rem_dec(self) {
use std::mem;
- self.slot.rx_unlock(self.condvar, false);
+ self.slot.rx_unlock(self.tail, self.condvar, false);
mem::forget(self);
}
@@ -993,7 +1010,7 @@ impl<'a, T> RecvGuard<'a, T> {
impl<'a, T> Drop for RecvGuard<'a, T> {
fn drop(&mut self) {
- self.slot.rx_unlock(self.condvar, true)
+ self.slot.rx_unlock(self.tail, self.condvar, true)
}
}
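One more detail worth spelling out: the switch from `fetch_add(1, SeqCst)` to `fetch_or(1, SeqCst)` when taking the write lock. In `Slot::lock`, bit 0 is the writer bit and each active reader contributes 2 (hence the `fetch_sub(2, SeqCst)` in `rx_unlock`). `fetch_add` is not idempotent, so a second pending send on the same slot would carry the writer bit into the reader count. A standalone sketch of the difference (illustration only, not tokio's code):

```rust
use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};

fn main() {
    // Encoding of `Slot::lock`: bit 0 = writer lock, readers add 2 each.
    let lock = AtomicUsize::new(0);

    // Two pending sends racing on one slot with the old `fetch_add`:
    lock.fetch_add(1, SeqCst);
    lock.fetch_add(1, SeqCst);
    // 1 + 1 = 2 now reads as "one reader, no writer": corrupted state.
    assert_eq!(lock.load(SeqCst), 2);

    // `fetch_or` is idempotent: the writer bit is set at most once,
    // however many senders race on the slot.
    let lock = AtomicUsize::new(0);
    assert_eq!(lock.fetch_or(1, SeqCst), 0); // first send takes the bit
    assert_eq!(lock.fetch_or(1, SeqCst), 1); // second send sees it held
    assert_eq!(lock.load(SeqCst), 1);
}
```

Together with the early `return Ok(rem)` paths added above, this lets a sender that lost the race on a slot bail out cleanly instead of panicking.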
diff --git a/tokio/src/sync/tests/loom_broadcast.rs b/tokio/src/sync/tests/loom_broadcast.rs
index da61563b..da12fb9f 100644
--- a/tokio/src/sync/tests/loom_broadcast.rs
+++ b/tokio/src/sync/tests/loom_broadcast.rs
@@ -6,6 +6,46 @@ use loom::sync::Arc;
use loom::thread;
use tokio_test::{assert_err, assert_ok};
+#[test]
+fn broadcast_send() {
+ loom::model(|| {
+ let (tx1, mut rx) = broadcast::channel(2);
+ let tx1 = Arc::new(tx1);
+ let tx2 = tx1.clone();
+
+ let th1 = thread::spawn(move || {
+ block_on(async {
+ assert_ok!(tx1.send("one"));
+ assert_ok!(tx1.send("two"));
+ assert_ok!(tx1.send("three"));
+ });
+ });
+
+ let th2 = thread::spawn(move || {
+ block_on(async {
+ assert_ok!(tx2.send("eins"));
+ assert_ok!(tx2.send("zwei"));
+ assert_ok!(tx2.send("drei"));
+ });
+ });
+
+ block_on(async {
+ let mut num = 0;
+ loop {
+ match rx.recv().await {
+ Ok(_) => num += 1,
+ Err(Closed) => break,
+ Err(Lagged(n)) => num += n as usize,
+ }
+ }
+ assert_eq!(num, 6);
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ });
+}
+
// An `Arc` is used as the value in order to detect memory leaks.
#[test]
fn broadcast_two() {