//! Futures task based helpers #![allow(clippy::mutex_atomic)] use std::future::Future; use std::mem; use std::ops; use std::pin::Pin; use std::sync::{Arc, Condvar, Mutex}; use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; use tokio::stream::Stream; /// TODO: dox pub fn spawn(task: T) -> Spawn { Spawn { task: MockTask::new(), future: Box::pin(task), } } /// Future spawned on a mock task #[derive(Debug)] pub struct Spawn { task: MockTask, future: Pin>, } /// Mock task /// /// A mock task is able to intercept and track wake notifications. #[derive(Debug, Clone)] struct MockTask { waker: Arc, } #[derive(Debug)] struct ThreadWaker { state: Mutex, condvar: Condvar, } const IDLE: usize = 0; const WAKE: usize = 1; const SLEEP: usize = 2; impl Spawn { /// Consumes `self` returning the inner value pub fn into_inner(self) -> T where T: Unpin, { *Pin::into_inner(self.future) } /// Returns `true` if the inner future has received a wake notification /// since the last call to `enter`. pub fn is_woken(&self) -> bool { self.task.is_woken() } /// Returns the number of references to the task waker /// /// The task itself holds a reference. The return value will never be zero. pub fn waker_ref_count(&self) -> usize { self.task.waker_ref_count() } /// Enter the task context pub fn enter(&mut self, f: F) -> R where F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R, { let fut = self.future.as_mut(); self.task.enter(|cx| f(cx, fut)) } } impl ops::Deref for Spawn { type Target = T; fn deref(&self) -> &T { &self.future } } impl ops::DerefMut for Spawn { fn deref_mut(&mut self) -> &mut T { &mut self.future } } impl Spawn { /// Polls a future pub fn poll(&mut self) -> Poll { let fut = self.future.as_mut(); self.task.enter(|cx| fut.poll(cx)) } } impl Spawn { /// Polls a stream pub fn poll_next(&mut self) -> Poll> { let stream = self.future.as_mut(); self.task.enter(|cx| stream.poll_next(cx)) } } impl Future for Spawn { type Output = T::Output; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.future.as_mut().poll(cx) } } impl Stream for Spawn { type Item = T::Item; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.future.as_mut().poll_next(cx) } } impl MockTask { /// Creates new mock task fn new() -> Self { MockTask { waker: Arc::new(ThreadWaker::new()), } } /// Runs a closure from the context of the task. /// /// Any wake notifications resulting from the execution of the closure are /// tracked. fn enter(&mut self, f: F) -> R where F: FnOnce(&mut Context<'_>) -> R, { self.waker.clear(); let waker = self.waker(); let mut cx = Context::from_waker(&waker); f(&mut cx) } /// Returns `true` if the inner future has received a wake notification /// since the last call to `enter`. fn is_woken(&self) -> bool { self.waker.is_woken() } /// Returns the number of references to the task waker /// /// The task itself holds a reference. The return value will never be zero. fn waker_ref_count(&self) -> usize { Arc::strong_count(&self.waker) } fn waker(&self) -> Waker { unsafe { let raw = to_raw(self.waker.clone()); Waker::from_raw(raw) } } } impl Default for MockTask { fn default() -> Self { Self::new() } } impl ThreadWaker { fn new() -> Self { ThreadWaker { state: Mutex::new(IDLE), condvar: Condvar::new(), } } /// Clears any previously received wakes, avoiding potential spurrious /// wake notifications. This should only be called immediately before running the /// task. fn clear(&self) { *self.state.lock().unwrap() = IDLE; } fn is_woken(&self) -> bool { match *self.state.lock().unwrap() { IDLE => false, WAKE => true, _ => unreachable!(), } } fn wake(&self) { // First, try transitioning from IDLE -> NOTIFY, this does not require a lock. let mut state = self.state.lock().unwrap(); let prev = *state; if prev == WAKE { return; } *state = WAKE; if prev == IDLE { return; } // The other half is sleeping, so we wake it up. assert_eq!(prev, SLEEP); self.condvar.notify_one(); } } static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker); unsafe fn to_raw(waker: Arc) -> RawWaker { RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE) } unsafe fn from_raw(raw: *const ()) -> Arc { Arc::from_raw(raw as *const ThreadWaker) } unsafe fn clone(raw: *const ()) -> RawWaker { let waker = from_raw(raw); // Increment the ref count mem::forget(waker.clone()); to_raw(waker) } unsafe fn wake(raw: *const ()) { let waker = from_raw(raw); waker.wake(); } unsafe fn wake_by_ref(raw: *const ()) { let waker = from_raw(raw); waker.wake(); // We don't actually own a reference to the unparker mem::forget(waker); } unsafe fn drop_waker(raw: *const ()) { let _ = from_raw(raw); }