diff options
Diffstat (limited to 'tokio/src/task')
-rw-r--r-- | tokio/src/task/mod.rs | 3 | ||||
-rw-r--r-- | tokio/src/task/task_local.rs | 240 |
2 files changed, 243 insertions, 0 deletions
diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index f762a561..efeb5f0e 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -257,6 +257,9 @@ cfg_rt_core! { cfg_rt_util! { mod local; pub use local::{spawn_local, LocalSet}; + + mod task_local; + pub use task_local::LocalKey; } cfg_rt_core! { diff --git a/tokio/src/task/task_local.rs b/tokio/src/task/task_local.rs new file mode 100644 index 00000000..6423175b --- /dev/null +++ b/tokio/src/task/task_local.rs @@ -0,0 +1,240 @@ +use pin_project_lite::pin_project; +use std::cell::RefCell; +use std::error::Error; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, thread}; + +/// Declare a new task local storage key of type [`tokio::task::LocalKey`]. +/// +/// # Syntax +/// +/// The macro wraps any number of static declarations and makes them task locals. +/// Publicity and attributes for each static are allowed. +/// +/// # Examples +/// +/// ``` +/// # use tokio::task_local; +/// task_local! { +/// pub static FOO: u32; +/// +/// #[allow(unused)] +/// static BAR: f32; +/// } +/// # fn main() {} +/// ``` +/// +/// See [LocalKey documentation][`tokio::task::LocalKey`] for more +/// information. +/// +/// [`tokio::task::LocalKey`]: ../tokio/task/struct.LocalKey.html +#[macro_export] +macro_rules! task_local { + // empty (base case for the recursion) + () => {}; + + ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty; $($rest:tt)*) => { + $crate::__task_local_inner!($(#[$attr])* $vis $name, $t); + $crate::task_local!($($rest)*); + }; + + ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty) => { + $crate::__task_local_inner!($(#[$attr])* $vis $name, $t); + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! __task_local_inner { + ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => { + static $name: $crate::task::LocalKey<$t> = { + std::thread_local! { + static __KEY: std::cell::RefCell<Option<$t>> = std::cell::RefCell::new(None); + } + + $crate::task::LocalKey { inner: __KEY } + }; + }; +} + +/// A key for task-local data. +/// +/// This type is generated by `task_local!` macro and unlike `thread_local!` it has +/// no concept of lazily initialization. Instead, it is designed to provide task local +/// storage the future that is passed to `set`. +/// +/// # Initialization and Destruction +/// +/// Initialization is done via `set` which is an `async fn` that wraps another +/// [`std::future::Future`] and will set the value on each `Future::poll` call. +/// Once the `set` future is dropped the corresponding task local value is also +/// dropped. +/// +/// # Examples +/// +/// ``` +/// # async fn dox() { +/// tokio::task_local! { +/// static FOO: u32; +/// } +/// +/// FOO.scope(1, async move { +/// assert_eq!(FOO.get(), 1); +/// }).await; +/// +/// FOO.scope(2, async move { +/// assert_eq!(FOO.get(), 2); +/// +/// FOO.scope(3, async move { +/// assert_eq!(FOO.get(), 3); +/// }).await; +/// }).await; +/// # } +/// ``` +pub struct LocalKey<T: 'static> { + #[doc(hidden)] + pub inner: thread::LocalKey<RefCell<Option<T>>>, +} + +impl<T: 'static> LocalKey<T> { + /// Sets a value `T` as the task local value for the future `F`. + /// + /// This will run the provided future to completion and set the + /// provided value as the task local under this key. Once the returned + /// future is dropped so will the value passed be dropped. + /// + /// # async fn dox() { + /// tokio::task_local! { + /// static FOO: u32; + /// } + /// + /// FOO.scope(1, async move { + /// println!("task local value: {}", FOO.get()); + /// }).await; + /// # } + pub async fn scope<F>(&'static self, value: T, f: F) -> F::Output + where + F: Future, + { + TaskLocalFuture { + local: &self, + slot: Some(value), + future: f, + } + .await + } + + /// Access this task-local key, running the provided closure with a reference + /// passed to the value. + /// + /// # Panics + /// + /// This function will panic if not called within a future that has not been + /// set via `LocalKey::set`. + pub fn with<F, R>(&'static self, f: F) -> R + where + F: FnOnce(&T) -> R, + { + self.try_with(f).expect( + "cannot access a Task Local Storage value \ + without setting it via `LocalKey::set`", + ) + } + + /// Access this task-local key, running the provided closure with a reference + /// passed to the value. Unlike `with` this function will return a `Result<R, AccessError>` + /// instead of panicking. + pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError> + where + F: FnOnce(&T) -> R, + { + self.inner.with(|v| { + if let Some(val) = v.borrow().as_ref() { + Ok(f(val)) + } else { + Err(AccessError { _private: () }) + } + }) + } +} + +impl<T: Copy + 'static> LocalKey<T> { + /// Get a copy of the task-local value if it implements + /// the `Copy` trait. + pub fn get(&'static self) -> T { + self.with(|v| *v) + } +} + +impl<T: 'static> fmt::Debug for LocalKey<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("LocalKey { .. }") + } +} + +pin_project! { + struct TaskLocalFuture<T: StaticLifetime, F> { + local: &'static LocalKey<T>, + slot: Option<T>, + #[pin] + future: F, + } +} + +impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + struct Guard<'a, T: 'static> { + local: &'static LocalKey<T>, + slot: &'a mut Option<T>, + prev: Option<T>, + } + + impl<T> Drop for Guard<'_, T> { + fn drop(&mut self) { + let value = self.local.inner.with(|c| c.replace(self.prev.take())); + *self.slot = value; + } + } + + let mut project = self.project(); + let val = project.slot.take(); + + let prev = project.local.inner.with(|c| c.replace(val)); + + let _guard = Guard { + prev, + slot: &mut project.slot, + local: *project.local, + }; + + project.future.poll(cx) + } +} + +// Required to make `pin_project` happy. +trait StaticLifetime: 'static {} +impl<T: 'static> StaticLifetime for T {} + +/// An error returned by [`LocalKey::try_with`](struct.LocalKey.html#method.try_with). +#[derive(Clone, Copy, Eq, PartialEq)] +pub struct AccessError { + _private: (), +} + +impl fmt::Debug for AccessError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AccessError").finish() + } +} + +impl fmt::Display for AccessError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt("task-local value not set", f) + } +} + +impl Error for AccessError {} |