From 2b909d6805990abf0bc2a5dea9e7267ff87df704 Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Tue, 29 Oct 2019 15:11:31 -0700 Subject: sync: move into `tokio` crate (#1705) A step towards collapsing Tokio sub crates into a single `tokio` crate (#1318). The sync implementation is now provided by the main `tokio` crate. Functionality can be opted out of by using the various net related feature flags. --- tokio/src/sync/watch.rs | 454 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 454 insertions(+) create mode 100644 tokio/src/sync/watch.rs (limited to 'tokio/src/sync/watch.rs') diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs new file mode 100644 index 00000000..30f1603f --- /dev/null +++ b/tokio/src/sync/watch.rs @@ -0,0 +1,454 @@ +//! A single-producer, multi-consumer channel that only retains the *last* sent +//! value. +//! +//! This channel is useful for watching for changes to a value from multiple +//! points in the code base, for example, changes to configuration values. +//! +//! # Usage +//! +//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are +//! the producer and sender halves of the channel. The channel is +//! created with an initial value. [`Receiver::get_ref`] will always +//! be ready upon creation and will yield either this initial value or +//! the latest value that has been sent by `Sender`. +//! +//! Calls to [`Receiver::get_ref`] will always yield the latest value. +//! +//! # Examples +//! +//! ``` +//! use tokio::sync::watch; +//! +//! # async fn dox() -> Result<(), Box> { +//! let (tx, mut rx) = watch::channel("hello"); +//! +//! tokio::spawn(async move { +//! while let Some(value) = rx.recv().await { +//! println!("received = {:?}", value); +//! } +//! }); +//! +//! tx.broadcast("world")?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Closing +//! +//! [`Sender::closed`] allows the producer to detect when all [`Receiver`] +//! handles have been dropped. This indicates that there is no further interest +//! in the values being produced and work can be stopped. +//! +//! # Thread safety +//! +//! Both [`Sender`] and [`Receiver`] are thread safe. They can be moved to other +//! threads and can be used in a concurrent environment. Clones of [`Receiver`] +//! handles may be moved to separate threads and also used concurrently. +//! +//! [`Sender`]: struct.Sender.html +//! [`Receiver`]: struct.Receiver.html +//! [`channel`]: fn.channel.html +//! [`Sender::closed`]: struct.Sender.html#method.closed +//! [`Receiver::get_ref`]: struct.Receiver.html#method.get_ref + +use crate::sync::task::AtomicWaker; + +use core::task::Poll::{Pending, Ready}; +use core::task::{Context, Poll}; +use fnv::FnvHashMap; +use futures_util::future::poll_fn; +use std::ops; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::SeqCst; +use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, Weak}; + +use futures_core::ready; +use futures_util::pin_mut; +use std::pin::Pin; + +/// Receives values from the associated [`Sender`](struct.Sender.html). +/// +/// Instances are created by the [`channel`](fn.channel.html) function. +#[derive(Debug)] +pub struct Receiver { + /// Pointer to the shared state + shared: Arc>, + + /// Pointer to the watcher's internal state + inner: Arc, + + /// Watcher ID. + id: u64, + + /// Last observed version + ver: usize, +} + +/// Sends values to the associated [`Receiver`](struct.Receiver.html). +/// +/// Instances are created by the [`channel`](fn.channel.html) function. +#[derive(Debug)] +pub struct Sender { + shared: Weak>, +} + +/// Returns a reference to the inner value +/// +/// Outstanding borrows hold a read lock on the inner value. This means that +/// long lived borrows could cause the produce half to block. It is recommended +/// to keep the borrow as short lived as possible. +#[derive(Debug)] +pub struct Ref<'a, T> { + inner: RwLockReadGuard<'a, T>, +} + +pub mod error { + //! Watch error types + + use std::fmt; + + /// Error produced when sending a value fails. + #[derive(Debug)] + pub struct SendError { + pub(crate) inner: T, + } + + // ===== impl SendError ===== + + impl fmt::Display for SendError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl ::std::error::Error for SendError {} +} + +#[derive(Debug)] +struct Shared { + /// The most recent value + value: RwLock, + + /// The current version + /// + /// The lowest bit represents a "closed" state. The rest of the bits + /// represent the current version. + version: AtomicUsize, + + /// All watchers + watchers: Mutex, + + /// Task to notify when all watchers drop + cancel: AtomicWaker, +} + +#[derive(Debug)] +struct Watchers { + next_id: u64, + watchers: FnvHashMap>, +} + +#[derive(Debug)] +struct WatchInner { + waker: AtomicWaker, +} + +const CLOSED: usize = 1; + +/// Create a new watch channel, returning the "send" and "receive" handles. +/// +/// All values sent by [`Sender`] will become visible to the [`Receiver`] handles. +/// Only the last value sent is made available to the [`Receiver`] half. All +/// intermediate values are dropped. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::watch; +/// +/// # async fn dox() -> Result<(), Box> { +/// let (tx, mut rx) = watch::channel("hello"); +/// +/// tokio::spawn(async move { +/// while let Some(value) = rx.recv().await { +/// println!("received = {:?}", value); +/// } +/// }); +/// +/// tx.broadcast("world")?; +/// # Ok(()) +/// # } +/// ``` +/// +/// [`Sender`]: struct.Sender.html +/// [`Receiver`]: struct.Receiver.html +pub fn channel(init: T) -> (Sender, Receiver) { + const INIT_ID: u64 = 0; + + let inner = Arc::new(WatchInner::new()); + + // Insert the watcher + let mut watchers = FnvHashMap::with_capacity_and_hasher(0, Default::default()); + watchers.insert(INIT_ID, inner.clone()); + + let shared = Arc::new(Shared { + value: RwLock::new(init), + version: AtomicUsize::new(2), + watchers: Mutex::new(Watchers { + next_id: INIT_ID + 1, + watchers, + }), + cancel: AtomicWaker::new(), + }); + + let tx = Sender { + shared: Arc::downgrade(&shared), + }; + + let rx = Receiver { + shared, + inner, + id: INIT_ID, + ver: 0, + }; + + (tx, rx) +} + +impl Receiver { + /// Returns a reference to the most recently sent value + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (_, rx) = watch::channel("hello"); + /// assert_eq!(*rx.get_ref(), "hello"); + /// ``` + pub fn get_ref(&self) -> Ref<'_, T> { + let inner = self.shared.value.read().unwrap(); + Ref { inner } + } + + /// Attempts to receive the latest value sent via the channel. + /// + /// If a new, unobserved, value has been sent, a reference to it is + /// returned. If no new value has been sent, then `Pending` is returned and + /// the current task is notified once a new value is sent. + /// + /// Only the **most recent** value is returned. If the receiver is falling + /// behind the sender, intermediate values are dropped. + pub async fn recv_ref(&mut self) -> Option> { + let shared = &self.shared; + let inner = &self.inner; + let version = self.ver; + + match poll_fn(|cx| poll_lock(cx, shared, inner, version)).await { + Some((lock, version)) => { + self.ver = version; + Some(lock) + } + None => None, + } + } +} + +fn poll_lock<'a, T>( + cx: &mut Context<'_>, + shared: &'a Arc>, + inner: &Arc, + ver: usize, +) -> Poll, usize)>> { + // Make sure the task is up to date + inner.waker.register_by_ref(cx.waker()); + + let state = shared.version.load(SeqCst); + let version = state & !CLOSED; + + if version != ver { + let inner = shared.value.read().unwrap(); + + return Ready(Some((Ref { inner }, version))); + } + + if CLOSED == state & CLOSED { + // The `Store` handle has been dropped. + return Ready(None); + } + + Pending +} + +impl Receiver { + /// Attempts to clone the latest value sent via the channel. + /// + /// This is equivalent to calling `clone()` on the value returned by + /// `recv_ref()`. + #[allow(clippy::map_clone)] // false positive: https://github.com/rust-lang/rust-clippy/issues/3274 + pub async fn recv(&mut self) -> Option { + self.recv_ref().await.map(|v_ref| v_ref.clone()) + } +} + +impl futures_core::Stream for Receiver { + type Item = T; + + #[allow(clippy::map_clone)] // false positive: https://github.com/rust-lang/rust-clippy/issues/3274 + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use std::future::Future; + + let fut = self.get_mut().recv(); + pin_mut!(fut); + + let item = ready!(fut.poll(cx)); + Ready(item.map(|v_ref| v_ref.clone())) + } +} + +impl Clone for Receiver { + fn clone(&self) -> Self { + let inner = Arc::new(WatchInner::new()); + let shared = self.shared.clone(); + + let id = { + let mut watchers = shared.watchers.lock().unwrap(); + let id = watchers.next_id; + + watchers.next_id += 1; + watchers.watchers.insert(id, inner.clone()); + + id + }; + + let ver = self.ver; + + Receiver { + shared, + inner, + id, + ver, + } + } +} + +impl Drop for Receiver { + fn drop(&mut self) { + let mut watchers = self.shared.watchers.lock().unwrap(); + watchers.watchers.remove(&self.id); + } +} + +impl WatchInner { + fn new() -> Self { + WatchInner { + waker: AtomicWaker::new(), + } + } +} + +impl Sender { + /// Broadcast a new value via the channel, notifying all receivers. + pub fn broadcast(&self, value: T) -> Result<(), error::SendError> { + let shared = match self.shared.upgrade() { + Some(shared) => shared, + // All `Watch` handles have been canceled + None => return Err(error::SendError { inner: value }), + }; + + // Replace the value + { + let mut lock = shared.value.write().unwrap(); + *lock = value; + } + + // Update the version. 2 is used so that the CLOSED bit is not set. + shared.version.fetch_add(2, SeqCst); + + // Notify all watchers + notify_all(&*shared); + + // Return the old value + Ok(()) + } + + /// Completes when all receivers have dropped. + /// + /// This allows the producer to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + pub async fn closed(&mut self) { + poll_fn(|cx| self.poll_close(cx)).await + } + + fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<()> { + match self.shared.upgrade() { + Some(shared) => { + shared.cancel.register_by_ref(cx.waker()); + Pending + } + None => Ready(()), + } + } +} + +impl futures_sink::Sink for Sender { + type Error = error::SendError; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + self.as_ref().get_ref().broadcast(item)?; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Ready(Ok(())) + } +} + +/// Notify all watchers of a change +fn notify_all(shared: &Shared) { + let watchers = shared.watchers.lock().unwrap(); + + for watcher in watchers.watchers.values() { + // Notify the task + watcher.waker.wake(); + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if let Some(shared) = self.shared.upgrade() { + shared.version.fetch_or(CLOSED, SeqCst); + notify_all(&*shared); + } + } +} + +// ===== impl Ref ===== + +impl ops::Deref for Ref<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + self.inner.deref() + } +} + +// ===== impl Shared ===== + +impl Drop for Shared { + fn drop(&mut self) { + self.cancel.wake(); + } +} -- cgit v1.2.3