summaryrefslogtreecommitdiffstats
path: root/tokio/src/sync/watch.rs
diff options
context:
space:
mode:
authorCarl Lerche <me@carllerche.com>2020-09-11 15:14:45 -0700
committerGitHub <noreply@github.com>2020-09-11 15:14:45 -0700
commit2bc9a4815259c8ff4daa5e24f128ec826970d17f (patch)
treec075e4d97a145ce104cfc8ee39d8d06acece5c13 /tokio/src/sync/watch.rs
parentc5a9ede157691ac5ca15283735bd666c6b016188 (diff)
sync: tweak `watch` API (#2814)
Decouples getting the latest `watch` value from receiving the change notification. The `Receiver` async method becomes `Receiver::changed()`. The latest value is obtained from `Receiver::borrow()`. The implementation is updated to use `Notify`. This requires adding `Notify::notify_waiters`. This method is generally useful but is kept private for now.
Diffstat (limited to 'tokio/src/sync/watch.rs')
-rw-r--r--tokio/src/sync/watch.rs353
1 files changed, 155 insertions, 198 deletions
diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs
index f6660b6e..7d1ac9e8 100644
--- a/tokio/src/sync/watch.rs
+++ b/tokio/src/sync/watch.rs
@@ -6,13 +6,11 @@
//!
//! # Usage
//!
-//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are
-//! the producer and sender halves of the channel. The channel is
-//! created with an initial value. [`Receiver::recv`] will always
-//! be ready upon creation and will yield either this initial value or
-//! the latest value that has been sent by `Sender`.
-//!
-//! Calls to [`Receiver::recv`] will always yield the latest value.
+//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are the producer
+//! and sender halves of the channel. The channel is created with an initial
+//! value. The **latest** value stored in the channel is accessed with
+//! [`Receiver::borrow()`]. Awaiting [`Receiver::changed()`] waits for a new
+//! value to sent by the [`Sender`] half.
//!
//! # Examples
//!
@@ -23,8 +21,8 @@
//! let (tx, mut rx) = watch::channel("hello");
//!
//! tokio::spawn(async move {
-//! while let Some(value) = Some(rx.recv().await) {
-//! println!("received = {:?}", value);
+//! while rx.changed().await.is_ok() {
+//! println!("received = {:?}", *rx.borrow());
//! }
//! });
//!
@@ -47,20 +45,17 @@
//!
//! [`Sender`]: crate::sync::watch::Sender
//! [`Receiver`]: crate::sync::watch::Receiver
-//! [`Receiver::recv`]: crate::sync::watch::Receiver::recv
+//! [`Receiver::changed()`]: crate::sync::watch::Receiver::changed
+//! [`Receiver::borrow()`]: crate::sync::watch::Receiver::borrow
//! [`channel`]: crate::sync::watch::channel
//! [`Sender::closed`]: crate::sync::watch::Sender::closed
-use crate::future::poll_fn;
-use crate::sync::task::AtomicWaker;
+use crate::sync::Notify;
-use fnv::FnvHashSet;
use std::ops;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Relaxed, SeqCst};
-use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, Weak};
-use std::task::Poll::{Pending, Ready};
-use std::task::{Context, Poll};
+use std::sync::{Arc, RwLock, RwLockReadGuard};
/// Receives values from the associated [`Sender`](struct@Sender).
///
@@ -70,8 +65,8 @@ pub struct Receiver<T> {
/// Pointer to the shared state
shared: Arc<Shared<T>>,
- /// Pointer to the watcher's internal state
- inner: Watcher,
+ /// Last observed version
+ version: usize,
}
/// Sends values to the associated [`Receiver`](struct@Receiver).
@@ -79,7 +74,7 @@ pub struct Receiver<T> {
/// Instances are created by the [`channel`](fn@channel) function.
#[derive(Debug)]
pub struct Sender<T> {
- shared: Weak<Shared<T>>,
+ shared: Arc<Shared<T>>,
}
/// Returns a reference to the inner value
@@ -92,6 +87,27 @@ pub struct Ref<'a, T> {
inner: RwLockReadGuard<'a, T>,
}
+#[derive(Debug)]
+struct Shared<T> {
+ /// The most recent value
+ value: RwLock<T>,
+
+ /// The current version
+ ///
+ /// The lowest bit represents a "closed" state. The rest of the bits
+ /// represent the current version.
+ version: AtomicUsize,
+
+ /// Tracks the number of `Receiver` instances
+ ref_count_rx: AtomicUsize,
+
+ /// Notifies waiting receivers that the value changed.
+ notify_rx: Notify,
+
+ /// Notifies any task listening for `Receiver` dropped events
+ notify_tx: Notify,
+}
+
pub mod error {
//! Watch error types
@@ -112,37 +128,20 @@ pub mod error {
}
impl<T: fmt::Debug> std::error::Error for SendError<T> {}
-}
-
-#[derive(Debug)]
-struct Shared<T> {
- /// The most recent value
- value: RwLock<T>,
-
- /// The current version
- ///
- /// The lowest bit represents a "closed" state. The rest of the bits
- /// represent the current version.
- version: AtomicUsize,
- /// All watchers
- watchers: Mutex<Watchers>,
-
- /// Task to notify when all watchers drop
- cancel: AtomicWaker,
-}
+ /// Error produced when receiving a change notification.
+ #[derive(Debug)]
+ pub struct RecvError(pub(super) ());
-type Watchers = FnvHashSet<Watcher>;
+ // ===== impl RecvError =====
-/// The watcher's ID is based on the Arc's pointer.
-#[derive(Clone, Debug)]
-struct Watcher(Arc<WatchInner>);
+ impl fmt::Display for RecvError {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(fmt, "channel closed")
+ }
+ }
-#[derive(Debug)]
-struct WatchInner {
- /// Last observed version
- version: AtomicUsize,
- waker: AtomicWaker,
+ impl std::error::Error for RecvError {}
}
const CLOSED: usize = 1;
@@ -162,8 +161,8 @@ const CLOSED: usize = 1;
/// let (tx, mut rx) = watch::channel("hello");
///
/// tokio::spawn(async move {
-/// while let Some(value) = Some(rx.recv().await) {
-/// println!("received = {:?}", value);
+/// while rx.changed().await.is_ok() {
+/// println!("received = {:?}", *rx.borrow());
/// }
/// });
///
@@ -174,29 +173,20 @@ const CLOSED: usize = 1;
///
/// [`Sender`]: struct@Sender
/// [`Receiver`]: struct@Receiver
-pub fn channel<T: Clone>(init: T) -> (Sender<T>, Receiver<T>) {
- const VERSION_0: usize = 0;
- const VERSION_1: usize = 2;
-
- // We don't start knowing VERSION_1
- let inner = Watcher::new_version(VERSION_0);
-
- // Insert the watcher
- let mut watchers = FnvHashSet::with_capacity_and_hasher(0, Default::default());
- watchers.insert(inner.clone());
-
+pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
let shared = Arc::new(Shared {
value: RwLock::new(init),
- version: AtomicUsize::new(VERSION_1),
- watchers: Mutex::new(watchers),
- cancel: AtomicWaker::new(),
+ version: AtomicUsize::new(0),
+ ref_count_rx: AtomicUsize::new(1),
+ notify_rx: Notify::new(),
+ notify_tx: Notify::new(),
});
let tx = Sender {
- shared: Arc::downgrade(&shared),
+ shared: shared.clone(),
};
- let rx = Receiver { shared, inner };
+ let rx = Receiver { shared, version: 0 };
(tx, rx)
}
@@ -221,41 +211,13 @@ impl<T> Receiver<T> {
Ref { inner }
}
- // TODO: document
- #[doc(hidden)]
- pub fn poll_recv_ref<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll<Ref<'a, T>> {
- // Make sure the task is up to date
- self.inner.waker.register_by_ref(cx.waker());
-
- let state = self.shared.version.load(SeqCst);
- let version = state & !CLOSED;
-
- if self.inner.version.swap(version, Relaxed) != version {
- let inner = self.shared.value.read().unwrap();
-
- return Ready(Ref { inner });
- }
-
- if CLOSED == state & CLOSED {
- // The `Store` handle has been dropped.
- let inner = self.shared.value.read().unwrap();
-
- return Ready(Ref { inner });
- }
-
- Pending
- }
-}
-
-impl<T: Clone> Receiver<T> {
- /// Attempts to clone the latest value sent via the channel.
+ /// Wait for a change notification
///
- /// If this is the first time the function is called on a `Receiver`
- /// instance, then the function completes immediately with the **current**
- /// value held by the channel. On the next call, the function waits until
- /// a new value is sent in the channel.
+ /// Returns when a new value has been sent by the [`Sender`] since the last
+ /// time `changed()` was called. When the `Sender` half is dropped, `Err` is
+ /// returned.
///
- /// `None` is returned if the `Sender` half is dropped.
+ /// [`Sender`]: struct@Sender
///
/// # Examples
///
@@ -266,79 +228,110 @@ impl<T: Clone> Receiver<T> {
/// async fn main() {
/// let (tx, mut rx) = watch::channel("hello");
///
- /// let v = rx.recv().await;
- /// assert_eq!(v, "hello");
- ///
/// tokio::spawn(async move {
/// tx.send("goodbye").unwrap();
/// });
///
- /// // Waits for the new task to spawn and send the value.
- /// let v = rx.recv().await;
- /// assert_eq!(v, "goodbye");
+ /// assert!(rx.changed().await.is_ok());
+ /// assert_eq!(*rx.borrow(), "goodbye");
///
- /// let v = rx.recv().await;
- /// assert_eq!(v, "goodbye");
+ /// // The `tx` handle has been dropped
+ /// assert!(rx.changed().await.is_err());
/// }
/// ```
- pub async fn recv(&mut self) -> T {
- poll_fn(|cx| {
- let v_ref = ready!(self.poll_recv_ref(cx));
- Poll::Ready((*v_ref).clone())
+ pub async fn changed(&mut self) -> Result<(), error::RecvError> {
+ use std::future::Future;
+ use std::pin::Pin;
+ use std::task::Poll;
+
+ // In order to avoid a race condition, we first request a notification,
+ // **then** check the current value's version. If a new version exists,
+ // the notification request is dropped. Requesting the notification
+ // requires polling the future once.
+ let notified = self.shared.notify_rx.notified();
+ pin!(notified);
+
+ // Polling the future once is guaranteed to return `Pending` as `watch`
+ // only notifies using `notify_waiters`.
+ crate::future::poll_fn(|cx| {
+ let res = Pin::new(&mut notified).poll(cx);
+ assert!(!res.is_ready());
+ Poll::Ready(())
})
- .await
+ .await;
+
+ if let Some(ret) = maybe_changed(&self.shared, &mut self.version) {
+ return ret;
+ }
+
+ notified.await;
+
+ maybe_changed(&self.shared, &mut self.version)
+ .expect("[bug] failed to observe change after notificaton.")
}
}
-#[cfg(feature = "stream")]
-impl<T: Clone> crate::stream::Stream for Receiver<T> {
- type Item = T;
-
- fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
- let v_ref = ready!(self.poll_recv_ref(cx));
+fn maybe_changed<T>(
+ shared: &Shared<T>,
+ version: &mut usize,
+) -> Option<Result<(), error::RecvError>> {
+ // Load the version from the state
+ let state = shared.version.load(SeqCst);
+ let new_version = state & !CLOSED;
+
+ if *version != new_version {
+ // Observe the new version and return
+ *version = new_version;
+ return Some(Ok(()));
+ }
- Poll::Ready(Some((*v_ref).clone()))
+ if CLOSED == state & CLOSED {
+ // All receivers have dropped.
+ return Some(Err(error::RecvError(())));
}
+
+ None
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
- let ver = self.inner.version.load(Relaxed);
- let inner = Watcher::new_version(ver);
+ let version = self.version;
let shared = self.shared.clone();
- shared.watchers.lock().unwrap().insert(inner.clone());
+ // No synchronization necessary as this is only used as a counter and
+ // not memory access.
+ shared.ref_count_rx.fetch_add(1, Relaxed);
- Receiver { shared, inner }
+ Receiver { version, shared }
}
}
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
- self.shared.watchers.lock().unwrap().remove(&self.inner);
+ // No synchronization necessary as this is only used as a counter and
+ // not memory access.
+ if 1 == self.shared.ref_count_rx.fetch_sub(1, Relaxed) {
+ // This is the last `Receiver` handle, tasks waiting on `Sender::closed()`
+ self.shared.notify_tx.notify_waiters();
+ }
}
}
impl<T> Sender<T> {
/// Sends a new value via the channel, notifying all receivers.
pub fn send(&self, value: T) -> Result<(), error::SendError<T>> {
- let shared = match self.shared.upgrade() {
- Some(shared) => shared,
- // All `Watch` handles have been canceled
- None => return Err(error::SendError { inner: value }),
- };
-
- // Replace the value
- {
- let mut lock = shared.value.write().unwrap();
- *lock = value;
+ // This is pretty much only useful as a hint anyway, so synchronization isn't critical.
+ if 0 == self.shared.ref_count_rx.load(Relaxed) {
+ return Err(error::SendError { inner: value });
}
+ *self.shared.value.write().unwrap() = value;
+
// Update the version. 2 is used so that the CLOSED bit is not set.
- shared.version.fetch_add(2, SeqCst);
+ self.shared.version.fetch_add(2, SeqCst);
// Notify all watchers
- notify_all(&*shared);
+ self.shared.notify_rx.notify_waiters();
Ok(())
}
@@ -347,37 +340,42 @@ impl<T> Sender<T> {
///
/// This allows the producer to get notified when interest in the produced
/// values is canceled and immediately stop doing work.
- pub async fn closed(&mut self) {
- poll_fn(|cx| self.poll_close(cx)).await
- }
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::watch;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, rx) = watch::channel("hello");
+ ///
+ /// tokio::spawn(async move {
+ /// // use `rx`
+ /// drop(rx);
+ /// });
+ ///
+ /// // Waits for `rx` to drop
+ /// tx.closed().await;
+ /// println!("the `rx` handles dropped")
+ /// }
+ /// ```
+ pub async fn closed(&self) {
+ let notified = self.shared.notify_tx.notified();
- fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<()> {
- match self.shared.upgrade() {
- Some(shared) => {
- shared.cancel.register_by_ref(cx.waker());
- Pending
- }
- None => Ready(()),
+ if self.shared.ref_count_rx.load(Relaxed) == 0 {
+ return;
}
- }
-}
-
-/// Notifies all watchers of a change
-fn notify_all<T>(shared: &Shared<T>) {
- let watchers = shared.watchers.lock().unwrap();
- for watcher in watchers.iter() {
- // Notify the task
- watcher.waker.wake();
+ notified.await;
+ debug_assert_eq!(0, self.shared.ref_count_rx.load(Relaxed));
}
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
- if let Some(shared) = self.shared.upgrade() {
- shared.version.fetch_or(CLOSED, SeqCst);
- notify_all(&*shared);
- }
+ self.shared.version.fetch_or(CLOSED, SeqCst);
+ self.shared.notify_rx.notify_waiters();
}
}
@@ -390,44 +388,3 @@ impl<T> ops::Deref for Ref<'_, T> {
self.inner.deref()
}
}
-
-// ===== impl Shared =====
-
-impl<T> Drop for Shared<T> {
- fn drop(&mut self) {
- self.cancel.wake();
- }
-}
-
-// ===== impl Watcher =====
-
-impl Watcher {
- fn new_version(version: usize) -> Self {
- Watcher(Arc::new(WatchInner {
- version: AtomicUsize::new(version),
- waker: AtomicWaker::new(),
- }))
- }
-}
-
-impl std::cmp::PartialEq for Watcher {
- fn eq(&self, other: &Watcher) -> bool {
- Arc::ptr_eq(&self.0, &other.0)
- }
-}
-
-impl std::cmp::Eq for Watcher {}
-
-impl std::hash::Hash for Watcher {
- fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- (&*self.0 as *const WatchInner).hash(state)
- }
-}
-
-impl std::ops::Deref for Watcher {
- type Target = WatchInner;
-
- fn deref(&self) -> &Self::Target {
- &self.0
- }
-}