summaryrefslogtreecommitdiffstats
path: root/tokio/src/task
diff options
context:
space:
mode:
Diffstat (limited to 'tokio/src/task')
-rw-r--r--tokio/src/task/mod.rs3
-rw-r--r--tokio/src/task/task_local.rs240
2 files changed, 243 insertions, 0 deletions
diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs
index f762a561..efeb5f0e 100644
--- a/tokio/src/task/mod.rs
+++ b/tokio/src/task/mod.rs
@@ -257,6 +257,9 @@ cfg_rt_core! {
cfg_rt_util! {
mod local;
pub use local::{spawn_local, LocalSet};
+
+ mod task_local;
+ pub use task_local::LocalKey;
}
cfg_rt_core! {
diff --git a/tokio/src/task/task_local.rs b/tokio/src/task/task_local.rs
new file mode 100644
index 00000000..6423175b
--- /dev/null
+++ b/tokio/src/task/task_local.rs
@@ -0,0 +1,240 @@
+use pin_project_lite::pin_project;
+use std::cell::RefCell;
+use std::error::Error;
+use std::future::Future;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::{fmt, thread};
+
+/// Declare a new task local storage key of type [`tokio::task::LocalKey`].
+///
+/// # Syntax
+///
+/// The macro wraps any number of static declarations and makes them task locals.
+/// Publicity and attributes for each static are allowed.
+///
+/// # Examples
+///
+/// ```
+/// # use tokio::task_local;
+/// task_local! {
+/// pub static FOO: u32;
+///
+/// #[allow(unused)]
+/// static BAR: f32;
+/// }
+/// # fn main() {}
+/// ```
+///
+/// See [LocalKey documentation][`tokio::task::LocalKey`] for more
+/// information.
+///
+/// [`tokio::task::LocalKey`]: ../tokio/task/struct.LocalKey.html
+#[macro_export]
+macro_rules! task_local {
+ // empty (base case for the recursion)
+ () => {};
+
+ ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty; $($rest:tt)*) => {
+ $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
+ $crate::task_local!($($rest)*);
+ };
+
+ ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty) => {
+ $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
+ }
+}
+
+#[doc(hidden)]
+#[macro_export]
+macro_rules! __task_local_inner {
+ ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
+ static $name: $crate::task::LocalKey<$t> = {
+ std::thread_local! {
+ static __KEY: std::cell::RefCell<Option<$t>> = std::cell::RefCell::new(None);
+ }
+
+ $crate::task::LocalKey { inner: __KEY }
+ };
+ };
+}
+
+/// A key for task-local data.
+///
+/// This type is generated by `task_local!` macro and unlike `thread_local!` it has
+/// no concept of lazily initialization. Instead, it is designed to provide task local
+/// storage the future that is passed to `set`.
+///
+/// # Initialization and Destruction
+///
+/// Initialization is done via `set` which is an `async fn` that wraps another
+/// [`std::future::Future`] and will set the value on each `Future::poll` call.
+/// Once the `set` future is dropped the corresponding task local value is also
+/// dropped.
+///
+/// # Examples
+///
+/// ```
+/// # async fn dox() {
+/// tokio::task_local! {
+/// static FOO: u32;
+/// }
+///
+/// FOO.scope(1, async move {
+/// assert_eq!(FOO.get(), 1);
+/// }).await;
+///
+/// FOO.scope(2, async move {
+/// assert_eq!(FOO.get(), 2);
+///
+/// FOO.scope(3, async move {
+/// assert_eq!(FOO.get(), 3);
+/// }).await;
+/// }).await;
+/// # }
+/// ```
+pub struct LocalKey<T: 'static> {
+ #[doc(hidden)]
+ pub inner: thread::LocalKey<RefCell<Option<T>>>,
+}
+
+impl<T: 'static> LocalKey<T> {
+ /// Sets a value `T` as the task local value for the future `F`.
+ ///
+ /// This will run the provided future to completion and set the
+ /// provided value as the task local under this key. Once the returned
+ /// future is dropped so will the value passed be dropped.
+ ///
+ /// # async fn dox() {
+ /// tokio::task_local! {
+ /// static FOO: u32;
+ /// }
+ ///
+ /// FOO.scope(1, async move {
+ /// println!("task local value: {}", FOO.get());
+ /// }).await;
+ /// # }
+ pub async fn scope<F>(&'static self, value: T, f: F) -> F::Output
+ where
+ F: Future,
+ {
+ TaskLocalFuture {
+ local: &self,
+ slot: Some(value),
+ future: f,
+ }
+ .await
+ }
+
+ /// Access this task-local key, running the provided closure with a reference
+ /// passed to the value.
+ ///
+ /// # Panics
+ ///
+ /// This function will panic if not called within a future that has not been
+ /// set via `LocalKey::set`.
+ pub fn with<F, R>(&'static self, f: F) -> R
+ where
+ F: FnOnce(&T) -> R,
+ {
+ self.try_with(f).expect(
+ "cannot access a Task Local Storage value \
+ without setting it via `LocalKey::set`",
+ )
+ }
+
+ /// Access this task-local key, running the provided closure with a reference
+ /// passed to the value. Unlike `with` this function will return a `Result<R, AccessError>`
+ /// instead of panicking.
+ pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
+ where
+ F: FnOnce(&T) -> R,
+ {
+ self.inner.with(|v| {
+ if let Some(val) = v.borrow().as_ref() {
+ Ok(f(val))
+ } else {
+ Err(AccessError { _private: () })
+ }
+ })
+ }
+}
+
+impl<T: Copy + 'static> LocalKey<T> {
+ /// Get a copy of the task-local value if it implements
+ /// the `Copy` trait.
+ pub fn get(&'static self) -> T {
+ self.with(|v| *v)
+ }
+}
+
+impl<T: 'static> fmt::Debug for LocalKey<T> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.pad("LocalKey { .. }")
+ }
+}
+
+pin_project! {
+ struct TaskLocalFuture<T: StaticLifetime, F> {
+ local: &'static LocalKey<T>,
+ slot: Option<T>,
+ #[pin]
+ future: F,
+ }
+}
+
+impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
+ type Output = F::Output;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ struct Guard<'a, T: 'static> {
+ local: &'static LocalKey<T>,
+ slot: &'a mut Option<T>,
+ prev: Option<T>,
+ }
+
+ impl<T> Drop for Guard<'_, T> {
+ fn drop(&mut self) {
+ let value = self.local.inner.with(|c| c.replace(self.prev.take()));
+ *self.slot = value;
+ }
+ }
+
+ let mut project = self.project();
+ let val = project.slot.take();
+
+ let prev = project.local.inner.with(|c| c.replace(val));
+
+ let _guard = Guard {
+ prev,
+ slot: &mut project.slot,
+ local: *project.local,
+ };
+
+ project.future.poll(cx)
+ }
+}
+
+// Required to make `pin_project` happy.
+trait StaticLifetime: 'static {}
+impl<T: 'static> StaticLifetime for T {}
+
+/// An error returned by [`LocalKey::try_with`](struct.LocalKey.html#method.try_with).
+#[derive(Clone, Copy, Eq, PartialEq)]
+pub struct AccessError {
+ _private: (),
+}
+
+impl fmt::Debug for AccessError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("AccessError").finish()
+ }
+}
+
+impl fmt::Display for AccessError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Display::fmt("task-local value not set", f)
+ }
+}
+
+impl Error for AccessError {}