author    | kalcutter <31195032+kalcutter@users.noreply.github.com> | 2020-01-22 22:59:05 +0100
committer | Carl Lerche <me@carllerche.com>                         | 2020-01-22 13:59:05 -0800
commit    | f9ea576ccae5beffeaa2f2c48c2c0d2f9449673b
tree      | f726788538b57d83f5e223648d7d21622ec86c41
parent    | 7f580071f3e5d475db200d2101ff35be0b4f6efe
sync: fix broadcast bugs (#2135)
Make sure the tail mutex is acquired when `condvar` is notified; otherwise
the wakeup may be lost and the sender left waiting. Use `notify_all()`
instead of `notify_one()` to ensure that the correct sender is woken.
Finally, only do this once the last reader has released the slot (a
standalone sketch of this lock-before-notify pattern follows below).
Additionally, fix a bug in `send()` that could cause a panic when the slot
still has another pending send.
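
The first fix above is the classic lock-before-notify rule for condition
variables: the notifier must hold the mutex that the waiter slept with,
otherwise the notification can land in the window between the waiter's last
check and its call to `wait()` and be lost. Below is a minimal standalone
sketch of that pattern using `std::sync` primitives; it is not part of the
patch, and the `Gate` type and its field names are illustrative only.

    use std::sync::{Arc, Condvar, Mutex};
    use std::thread;

    /// Illustrative flag guarded by a mutex plus a condvar, mirroring the
    /// tail-mutex/condvar pairing in the broadcast channel.
    struct Gate {
        ready: Mutex<bool>,
        condvar: Condvar,
    }

    impl Gate {
        fn wait(&self) {
            let mut ready = self.ready.lock().unwrap();
            while !*ready {
                // Atomically releases the mutex and parks; the mutex is
                // re-acquired before `wait` returns.
                ready = self.condvar.wait(ready).unwrap();
            }
        }

        fn open(&self) {
            // Writing the flag while holding the mutex guarantees the waiter
            // either has not checked it yet (and will observe `true`) or is
            // already parked inside `wait()`; the update cannot slip into the
            // gap between the waiter's check and its sleep.
            *self.ready.lock().unwrap() = true;
            // `notify_all` wakes every waiter; `notify_one` could pick a
            // thread that is not the one blocked on this condition.
            self.condvar.notify_all();
        }
    }

    fn main() {
        let gate = Arc::new(Gate {
            ready: Mutex::new(false),
            condvar: Condvar::new(),
        });
        let g = gate.clone();
        let waiter = thread::spawn(move || g.wait());
        gate.open();
        waiter.join().unwrap();
    }

In the patch, `Slot::rx_unlock()` plays the role of `open()`: it takes the
tail mutex before calling `condvar.notify_all()`, while `send()` waits on the
same condvar with the tail mutex held.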
-rw-r--r-- | tokio/src/sync/broadcast.rs            | 49
-rw-r--r-- | tokio/src/sync/tests/loom_broadcast.rs | 40
2 files changed, 73 insertions, 16 deletions
diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs
index e020a3c3..35854811 100644
--- a/tokio/src/sync/broadcast.rs
+++ b/tokio/src/sync/broadcast.rs
@@ -214,9 +214,8 @@ pub enum RecvError {
     /// be sent.
     Closed,
 
-    /// The receiver lagged too far behind and has been forcibly disconnected.
-    /// Attempting to receive again will return the oldest message still
-    /// retained by the channel.
+    /// The receiver lagged too far behind. Attempting to receive again will
+    /// return the oldest message still retained by the channel.
     ///
     /// Includes the number of skipped messages.
     Lagged(u64),
@@ -274,9 +273,9 @@ struct Tail {
     rx_cnt: usize,
 }
 
-/// Node in the linked list
+/// Slot in the buffer
struct Slot<T> {
-    /// Remaining numer of senders that are expected to see this value.
+    /// Remaining number of receivers that are expected to see this value.
     ///
     /// When this goes to zero, the value is released.
     rem: AtomicUsize,
@@ -314,6 +313,7 @@ struct WaitNode {
 
 struct RecvGuard<'a, T> {
     slot: &'a Slot<T>,
+    tail: &'a Mutex<Tail>,
     condvar: &'a Condvar,
 }
 
@@ -579,17 +579,27 @@ impl<T> Sender<T> {
         let slot = &self.shared.buffer[idx];
 
         // Acquire the write lock
-        let mut prev = slot.lock.fetch_add(1, SeqCst);
+        let mut prev = slot.lock.fetch_or(1, SeqCst);
 
         while prev & !1 != 0 {
             // Concurrent readers, we must go to sleep
             tail = self.shared.condvar.wait(tail).unwrap();
 
             prev = slot.lock.load(SeqCst);
+
+            if prev & 1 == 0 {
+                // The writer lock bit was cleared while this thread was
+                // sleeping. This can only happen if a newer write happened on
+                // this slot by another thread. Bail early as an optimization,
+                // there is nothing left to do.
+                return Ok(rem);
+            }
         }
 
-        // Release the mutex
-        drop(tail);
+        if tail.pos.wrapping_sub(pos) > self.shared.buffer.len() as u64 {
+            // There is a newer pending write to the same slot.
+            return Ok(rem);
+        }
 
         // Slot lock acquired
         slot.write.pos.with_mut(|ptr| unsafe { *ptr = pos });
@@ -601,6 +611,11 @@ impl<T> Sender<T> {
         // Release the slot lock
         slot.lock.store(0, SeqCst);
 
+        // Release the mutex. This must happen after the slot lock is released,
+        // otherwise the writer lock bit could be cleared while another thread
+        // is in the critical section.
+        drop(tail);
+
         // Notify waiting receivers
         self.notify_rx();
 
@@ -665,6 +680,7 @@ impl<T> Receiver<T> {
 
         let guard = RecvGuard {
             slot,
+            tail: &self.shared.tail,
            condvar: &self.shared.condvar,
         };
 
@@ -952,7 +968,7 @@ impl<T> Slot<T> {
         }
     }
 
-    fn rx_unlock(&self, condvar: &Condvar, rem_dec: bool) {
+    fn rx_unlock(&self, tail: &Mutex<Tail>, condvar: &Condvar, rem_dec: bool) {
         if rem_dec {
             // Decrement the remaining counter
             if 1 == self.rem.fetch_sub(1, SeqCst) {
@@ -961,11 +977,12 @@ impl<T> Slot<T> {
             }
         }
 
-        let prev = self.lock.fetch_sub(2, SeqCst);
-
-        if prev & 1 == 1 {
-            // Sender waiting for lock
-            condvar.notify_one();
+        if 1 == self.lock.fetch_sub(2, SeqCst) - 2 {
+            // First acquire the lock to make sure our sender is waiting on the
+            // condition variable, otherwise the notification could be lost.
+            let _ = tail.lock().unwrap();
+            // Wake up senders
+            condvar.notify_all();
         }
     }
 }
@@ -985,7 +1002,7 @@ impl<'a, T> RecvGuard<'a, T> {
     fn drop_no_rem_dec(self) {
         use std::mem;
 
-        self.slot.rx_unlock(self.condvar, false);
+        self.slot.rx_unlock(self.tail, self.condvar, false);
 
         mem::forget(self);
     }
@@ -993,7 +1010,7 @@ impl<'a, T> RecvGuard<'a, T> {
 
 impl<'a, T> Drop for RecvGuard<'a, T> {
     fn drop(&mut self) {
-        self.slot.rx_unlock(self.condvar, true)
+        self.slot.rx_unlock(self.tail, self.condvar, true)
     }
 }
 
diff --git a/tokio/src/sync/tests/loom_broadcast.rs b/tokio/src/sync/tests/loom_broadcast.rs
index da61563b..da12fb9f 100644
--- a/tokio/src/sync/tests/loom_broadcast.rs
+++ b/tokio/src/sync/tests/loom_broadcast.rs
@@ -6,6 +6,46 @@ use loom::sync::Arc;
 use loom::thread;
 use tokio_test::{assert_err, assert_ok};
 
+#[test]
+fn broadcast_send() {
+    loom::model(|| {
+        let (tx1, mut rx) = broadcast::channel(2);
+        let tx1 = Arc::new(tx1);
+        let tx2 = tx1.clone();
+
+        let th1 = thread::spawn(move || {
+            block_on(async {
+                assert_ok!(tx1.send("one"));
+                assert_ok!(tx1.send("two"));
+                assert_ok!(tx1.send("three"));
+            });
+        });
+
+        let th2 = thread::spawn(move || {
+            block_on(async {
+                assert_ok!(tx2.send("eins"));
+                assert_ok!(tx2.send("zwei"));
+                assert_ok!(tx2.send("drei"));
+            });
+        });
+
+        block_on(async {
+            let mut num = 0;
+            loop {
+                match rx.recv().await {
+                    Ok(_) => num += 1,
+                    Err(Closed) => break,
+                    Err(Lagged(n)) => num += n as usize,
+                }
+            }
+            assert_eq!(num, 6);
+        });
+
+        assert_ok!(th1.join());
+        assert_ok!(th2.join());
+    });
+}
+
 // An `Arc` is used as the value in order to detect memory leaks.
 #[test]
 fn broadcast_two() {
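
For a sense of how the `Lagged` behavior documented above surfaces to users,
here is a rough usage sketch against the broadcast API of the tokio 0.2 line
this patch targets (the exact paths, e.g. `tokio::sync::broadcast::RecvError`,
are assumptions for that version, not part of the patch): a receiver that
falls behind a small buffer observes `Lagged(n)` and then resumes from the
oldest retained message, and observes `Closed` once every sender has been
dropped and the buffer is drained.

    use tokio::sync::broadcast::{self, RecvError};

    #[tokio::main]
    async fn main() {
        // Capacity 2: sending four values before the receiver reads anything
        // forces the receiver to lag.
        let (tx, mut rx) = broadcast::channel(2);

        for i in 0..4u32 {
            tx.send(i).unwrap();
        }
        drop(tx);

        loop {
            match rx.recv().await {
                Ok(v) => println!("got {}", v),
                // `n` messages were skipped; the next recv() yields the oldest
                // value still retained by the channel.
                Err(RecvError::Lagged(n)) => println!("lagged by {}", n),
                // Every sender is gone and the buffer has been drained.
                Err(RecvError::Closed) => break,
            }
        }
    }

This mirrors the loom test added in the diff, which counts both delivered and
lagged messages so that all six sends are accounted for.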
+ let _ = tail.lock().unwrap(); + // Wake up senders + condvar.notify_all(); } } } @@ -985,7 +1002,7 @@ impl<'a, T> RecvGuard<'a, T> { fn drop_no_rem_dec(self) { use std::mem; - self.slot.rx_unlock(self.condvar, false); + self.slot.rx_unlock(self.tail, self.condvar, false); mem::forget(self); } @@ -993,7 +1010,7 @@ impl<'a, T> RecvGuard<'a, T> { impl<'a, T> Drop for RecvGuard<'a, T> { fn drop(&mut self) { - self.slot.rx_unlock(self.condvar, true) + self.slot.rx_unlock(self.tail, self.condvar, true) } } diff --git a/tokio/src/sync/tests/loom_broadcast.rs b/tokio/src/sync/tests/loom_broadcast.rs index da61563b..da12fb9f 100644 --- a/tokio/src/sync/tests/loom_broadcast.rs +++ b/tokio/src/sync/tests/loom_broadcast.rs @@ -6,6 +6,46 @@ use loom::sync::Arc; use loom::thread; use tokio_test::{assert_err, assert_ok}; +#[test] +fn broadcast_send() { + loom::model(|| { + let (tx1, mut rx) = broadcast::channel(2); + let tx1 = Arc::new(tx1); + let tx2 = tx1.clone(); + + let th1 = thread::spawn(move || { + block_on(async { + assert_ok!(tx1.send("one")); + assert_ok!(tx1.send("two")); + assert_ok!(tx1.send("three")); + }); + }); + + let th2 = thread::spawn(move || { + block_on(async { + assert_ok!(tx2.send("eins")); + assert_ok!(tx2.send("zwei")); + assert_ok!(tx2.send("drei")); + }); + }); + + block_on(async { + let mut num = 0; + loop { + match rx.recv().await { + Ok(_) => num += 1, + Err(Closed) => break, + Err(Lagged(n)) => num += n as usize, + } + } + assert_eq!(num, 6); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + // An `Arc` is used as the value in order to detect memory leaks. #[test] fn broadcast_two() { |