From 2bc9a4815259c8ff4daa5e24f128ec826970d17f Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Fri, 11 Sep 2020 15:14:45 -0700 Subject: sync: tweak `watch` API (#2814) Decouples getting the latest `watch` value from receiving the change notification. The `Receiver` async method becomes `Receiver::changed()`. The latest value is obtained from `Receiver::borrow()`. The implementation is updated to use `Notify`. This requires adding `Notify::notify_waiters`. This method is generally useful but is kept private for now. --- tokio/src/sync/watch.rs | 353 +++++++++++++++++++++--------------------------- 1 file changed, 155 insertions(+), 198 deletions(-) (limited to 'tokio/src/sync/watch.rs') diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index f6660b6e..7d1ac9e8 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -6,13 +6,11 @@ //! //! # Usage //! -//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are -//! the producer and sender halves of the channel. The channel is -//! created with an initial value. [`Receiver::recv`] will always -//! be ready upon creation and will yield either this initial value or -//! the latest value that has been sent by `Sender`. -//! -//! Calls to [`Receiver::recv`] will always yield the latest value. +//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are the producer +//! and sender halves of the channel. The channel is created with an initial +//! value. The **latest** value stored in the channel is accessed with +//! [`Receiver::borrow()`]. Awaiting [`Receiver::changed()`] waits for a new +//! value to sent by the [`Sender`] half. //! //! # Examples //! @@ -23,8 +21,8 @@ //! let (tx, mut rx) = watch::channel("hello"); //! //! tokio::spawn(async move { -//! while let Some(value) = Some(rx.recv().await) { -//! println!("received = {:?}", value); +//! while rx.changed().await.is_ok() { +//! println!("received = {:?}", *rx.borrow()); //! } //! }); //! @@ -47,20 +45,17 @@ //! //! [`Sender`]: crate::sync::watch::Sender //! [`Receiver`]: crate::sync::watch::Receiver -//! [`Receiver::recv`]: crate::sync::watch::Receiver::recv +//! [`Receiver::changed()`]: crate::sync::watch::Receiver::changed +//! [`Receiver::borrow()`]: crate::sync::watch::Receiver::borrow //! [`channel`]: crate::sync::watch::channel //! [`Sender::closed`]: crate::sync::watch::Sender::closed -use crate::future::poll_fn; -use crate::sync::task::AtomicWaker; +use crate::sync::Notify; -use fnv::FnvHashSet; use std::ops; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::{Relaxed, SeqCst}; -use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, Weak}; -use std::task::Poll::{Pending, Ready}; -use std::task::{Context, Poll}; +use std::sync::{Arc, RwLock, RwLockReadGuard}; /// Receives values from the associated [`Sender`](struct@Sender). /// @@ -70,8 +65,8 @@ pub struct Receiver { /// Pointer to the shared state shared: Arc>, - /// Pointer to the watcher's internal state - inner: Watcher, + /// Last observed version + version: usize, } /// Sends values to the associated [`Receiver`](struct@Receiver). @@ -79,7 +74,7 @@ pub struct Receiver { /// Instances are created by the [`channel`](fn@channel) function. #[derive(Debug)] pub struct Sender { - shared: Weak>, + shared: Arc>, } /// Returns a reference to the inner value @@ -92,6 +87,27 @@ pub struct Ref<'a, T> { inner: RwLockReadGuard<'a, T>, } +#[derive(Debug)] +struct Shared { + /// The most recent value + value: RwLock, + + /// The current version + /// + /// The lowest bit represents a "closed" state. The rest of the bits + /// represent the current version. + version: AtomicUsize, + + /// Tracks the number of `Receiver` instances + ref_count_rx: AtomicUsize, + + /// Notifies waiting receivers that the value changed. + notify_rx: Notify, + + /// Notifies any task listening for `Receiver` dropped events + notify_tx: Notify, +} + pub mod error { //! Watch error types @@ -112,37 +128,20 @@ pub mod error { } impl std::error::Error for SendError {} -} - -#[derive(Debug)] -struct Shared { - /// The most recent value - value: RwLock, - - /// The current version - /// - /// The lowest bit represents a "closed" state. The rest of the bits - /// represent the current version. - version: AtomicUsize, - /// All watchers - watchers: Mutex, - - /// Task to notify when all watchers drop - cancel: AtomicWaker, -} + /// Error produced when receiving a change notification. + #[derive(Debug)] + pub struct RecvError(pub(super) ()); -type Watchers = FnvHashSet; + // ===== impl RecvError ===== -/// The watcher's ID is based on the Arc's pointer. -#[derive(Clone, Debug)] -struct Watcher(Arc); + impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } -#[derive(Debug)] -struct WatchInner { - /// Last observed version - version: AtomicUsize, - waker: AtomicWaker, + impl std::error::Error for RecvError {} } const CLOSED: usize = 1; @@ -162,8 +161,8 @@ const CLOSED: usize = 1; /// let (tx, mut rx) = watch::channel("hello"); /// /// tokio::spawn(async move { -/// while let Some(value) = Some(rx.recv().await) { -/// println!("received = {:?}", value); +/// while rx.changed().await.is_ok() { +/// println!("received = {:?}", *rx.borrow()); /// } /// }); /// @@ -174,29 +173,20 @@ const CLOSED: usize = 1; /// /// [`Sender`]: struct@Sender /// [`Receiver`]: struct@Receiver -pub fn channel(init: T) -> (Sender, Receiver) { - const VERSION_0: usize = 0; - const VERSION_1: usize = 2; - - // We don't start knowing VERSION_1 - let inner = Watcher::new_version(VERSION_0); - - // Insert the watcher - let mut watchers = FnvHashSet::with_capacity_and_hasher(0, Default::default()); - watchers.insert(inner.clone()); - +pub fn channel(init: T) -> (Sender, Receiver) { let shared = Arc::new(Shared { value: RwLock::new(init), - version: AtomicUsize::new(VERSION_1), - watchers: Mutex::new(watchers), - cancel: AtomicWaker::new(), + version: AtomicUsize::new(0), + ref_count_rx: AtomicUsize::new(1), + notify_rx: Notify::new(), + notify_tx: Notify::new(), }); let tx = Sender { - shared: Arc::downgrade(&shared), + shared: shared.clone(), }; - let rx = Receiver { shared, inner }; + let rx = Receiver { shared, version: 0 }; (tx, rx) } @@ -221,41 +211,13 @@ impl Receiver { Ref { inner } } - // TODO: document - #[doc(hidden)] - pub fn poll_recv_ref<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll> { - // Make sure the task is up to date - self.inner.waker.register_by_ref(cx.waker()); - - let state = self.shared.version.load(SeqCst); - let version = state & !CLOSED; - - if self.inner.version.swap(version, Relaxed) != version { - let inner = self.shared.value.read().unwrap(); - - return Ready(Ref { inner }); - } - - if CLOSED == state & CLOSED { - // The `Store` handle has been dropped. - let inner = self.shared.value.read().unwrap(); - - return Ready(Ref { inner }); - } - - Pending - } -} - -impl Receiver { - /// Attempts to clone the latest value sent via the channel. + /// Wait for a change notification /// - /// If this is the first time the function is called on a `Receiver` - /// instance, then the function completes immediately with the **current** - /// value held by the channel. On the next call, the function waits until - /// a new value is sent in the channel. + /// Returns when a new value has been sent by the [`Sender`] since the last + /// time `changed()` was called. When the `Sender` half is dropped, `Err` is + /// returned. /// - /// `None` is returned if the `Sender` half is dropped. + /// [`Sender`]: struct@Sender /// /// # Examples /// @@ -266,79 +228,110 @@ impl Receiver { /// async fn main() { /// let (tx, mut rx) = watch::channel("hello"); /// - /// let v = rx.recv().await; - /// assert_eq!(v, "hello"); - /// /// tokio::spawn(async move { /// tx.send("goodbye").unwrap(); /// }); /// - /// // Waits for the new task to spawn and send the value. - /// let v = rx.recv().await; - /// assert_eq!(v, "goodbye"); + /// assert!(rx.changed().await.is_ok()); + /// assert_eq!(*rx.borrow(), "goodbye"); /// - /// let v = rx.recv().await; - /// assert_eq!(v, "goodbye"); + /// // The `tx` handle has been dropped + /// assert!(rx.changed().await.is_err()); /// } /// ``` - pub async fn recv(&mut self) -> T { - poll_fn(|cx| { - let v_ref = ready!(self.poll_recv_ref(cx)); - Poll::Ready((*v_ref).clone()) + pub async fn changed(&mut self) -> Result<(), error::RecvError> { + use std::future::Future; + use std::pin::Pin; + use std::task::Poll; + + // In order to avoid a race condition, we first request a notification, + // **then** check the current value's version. If a new version exists, + // the notification request is dropped. Requesting the notification + // requires polling the future once. + let notified = self.shared.notify_rx.notified(); + pin!(notified); + + // Polling the future once is guaranteed to return `Pending` as `watch` + // only notifies using `notify_waiters`. + crate::future::poll_fn(|cx| { + let res = Pin::new(&mut notified).poll(cx); + assert!(!res.is_ready()); + Poll::Ready(()) }) - .await + .await; + + if let Some(ret) = maybe_changed(&self.shared, &mut self.version) { + return ret; + } + + notified.await; + + maybe_changed(&self.shared, &mut self.version) + .expect("[bug] failed to observe change after notificaton.") } } -#[cfg(feature = "stream")] -impl crate::stream::Stream for Receiver { - type Item = T; - - fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let v_ref = ready!(self.poll_recv_ref(cx)); +fn maybe_changed( + shared: &Shared, + version: &mut usize, +) -> Option> { + // Load the version from the state + let state = shared.version.load(SeqCst); + let new_version = state & !CLOSED; + + if *version != new_version { + // Observe the new version and return + *version = new_version; + return Some(Ok(())); + } - Poll::Ready(Some((*v_ref).clone())) + if CLOSED == state & CLOSED { + // All receivers have dropped. + return Some(Err(error::RecvError(()))); } + + None } impl Clone for Receiver { fn clone(&self) -> Self { - let ver = self.inner.version.load(Relaxed); - let inner = Watcher::new_version(ver); + let version = self.version; let shared = self.shared.clone(); - shared.watchers.lock().unwrap().insert(inner.clone()); + // No synchronization necessary as this is only used as a counter and + // not memory access. + shared.ref_count_rx.fetch_add(1, Relaxed); - Receiver { shared, inner } + Receiver { version, shared } } } impl Drop for Receiver { fn drop(&mut self) { - self.shared.watchers.lock().unwrap().remove(&self.inner); + // No synchronization necessary as this is only used as a counter and + // not memory access. + if 1 == self.shared.ref_count_rx.fetch_sub(1, Relaxed) { + // This is the last `Receiver` handle, tasks waiting on `Sender::closed()` + self.shared.notify_tx.notify_waiters(); + } } } impl Sender { /// Sends a new value via the channel, notifying all receivers. pub fn send(&self, value: T) -> Result<(), error::SendError> { - let shared = match self.shared.upgrade() { - Some(shared) => shared, - // All `Watch` handles have been canceled - None => return Err(error::SendError { inner: value }), - }; - - // Replace the value - { - let mut lock = shared.value.write().unwrap(); - *lock = value; + // This is pretty much only useful as a hint anyway, so synchronization isn't critical. + if 0 == self.shared.ref_count_rx.load(Relaxed) { + return Err(error::SendError { inner: value }); } + *self.shared.value.write().unwrap() = value; + // Update the version. 2 is used so that the CLOSED bit is not set. - shared.version.fetch_add(2, SeqCst); + self.shared.version.fetch_add(2, SeqCst); // Notify all watchers - notify_all(&*shared); + self.shared.notify_rx.notify_waiters(); Ok(()) } @@ -347,37 +340,42 @@ impl Sender { /// /// This allows the producer to get notified when interest in the produced /// values is canceled and immediately stop doing work. - pub async fn closed(&mut self) { - poll_fn(|cx| self.poll_close(cx)).await - } + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = watch::channel("hello"); + /// + /// tokio::spawn(async move { + /// // use `rx` + /// drop(rx); + /// }); + /// + /// // Waits for `rx` to drop + /// tx.closed().await; + /// println!("the `rx` handles dropped") + /// } + /// ``` + pub async fn closed(&self) { + let notified = self.shared.notify_tx.notified(); - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<()> { - match self.shared.upgrade() { - Some(shared) => { - shared.cancel.register_by_ref(cx.waker()); - Pending - } - None => Ready(()), + if self.shared.ref_count_rx.load(Relaxed) == 0 { + return; } - } -} - -/// Notifies all watchers of a change -fn notify_all(shared: &Shared) { - let watchers = shared.watchers.lock().unwrap(); - for watcher in watchers.iter() { - // Notify the task - watcher.waker.wake(); + notified.await; + debug_assert_eq!(0, self.shared.ref_count_rx.load(Relaxed)); } } impl Drop for Sender { fn drop(&mut self) { - if let Some(shared) = self.shared.upgrade() { - shared.version.fetch_or(CLOSED, SeqCst); - notify_all(&*shared); - } + self.shared.version.fetch_or(CLOSED, SeqCst); + self.shared.notify_rx.notify_waiters(); } } @@ -390,44 +388,3 @@ impl ops::Deref for Ref<'_, T> { self.inner.deref() } } - -// ===== impl Shared ===== - -impl Drop for Shared { - fn drop(&mut self) { - self.cancel.wake(); - } -} - -// ===== impl Watcher ===== - -impl Watcher { - fn new_version(version: usize) -> Self { - Watcher(Arc::new(WatchInner { - version: AtomicUsize::new(version), - waker: AtomicWaker::new(), - })) - } -} - -impl std::cmp::PartialEq for Watcher { - fn eq(&self, other: &Watcher) -> bool { - Arc::ptr_eq(&self.0, &other.0) - } -} - -impl std::cmp::Eq for Watcher {} - -impl std::hash::Hash for Watcher { - fn hash(&self, state: &mut H) { - (&*self.0 as *const WatchInner).hash(state) - } -} - -impl std::ops::Deref for Watcher { - type Target = WatchInner; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} -- cgit v1.2.3