diff options
author | Artem Vorotnikov <artem@vorotnikov.me> | 2019-12-27 21:56:43 +0300 |
---|---|---|
committer | Lucio Franco <luciofranco14@gmail.com> | 2019-12-27 13:56:43 -0500 |
commit | e8fcf55881f0d97177d7f5b2d5c0b00704c26fbe (patch) | |
tree | 4f9b6fda2599a9327814dbcb9af8d52a7f2fe1b6 /tokio-macros | |
parent | a515f9c459d662b9c93d962812dc1fd8d1b32e08 (diff) |
Refactor proc macros, add more knobs (#2022)
* Refactor proc macros, add more knobs
* make macros work with rt-core
Diffstat (limited to 'tokio-macros')
-rw-r--r-- | tokio-macros/src/lib.rs | 355 |
1 files changed, 235 insertions, 120 deletions
diff --git a/tokio-macros/src/lib.rs b/tokio-macros/src/lib.rs index 1116735a..fcdba2bf 100644 --- a/tokio-macros/src/lib.rs +++ b/tokio-macros/src/lib.rs @@ -18,19 +18,186 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::quote; +use std::num::NonZeroUsize; enum Runtime { Basic, Threaded, - Auto, +} + +fn parse_knobs( + input: syn::ItemFn, + args: syn::AttributeArgs, + is_test: bool, + rt_threaded: bool, +) -> Result<TokenStream, syn::Error> { + let ret = &input.sig.output; + let name = &input.sig.ident; + let inputs = &input.sig.inputs; + let body = &input.block; + let attrs = &input.attrs; + let vis = input.vis; + + if input.sig.asyncness.is_none() { + let msg = "the async keyword is missing from the function declaration"; + return Err(syn::Error::new_spanned(input.sig.fn_token, msg)); + } + + let mut runtime = None; + let mut core_threads = None; + let mut max_threads = None; + + for arg in args { + match arg { + syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) => { + let ident = namevalue.path.get_ident(); + if ident.is_none() { + let msg = "Must have specified ident"; + return Err(syn::Error::new_spanned(namevalue, msg)); + } + match ident.unwrap().to_string().to_lowercase().as_str() { + "core_threads" => { + if rt_threaded { + match &namevalue.lit { + syn::Lit::Int(expr) => { + let num = expr.base10_parse::<NonZeroUsize>().unwrap(); + if num.get() > 1 { + runtime = Some(Runtime::Threaded); + } else { + runtime = Some(Runtime::Basic); + } + + if let Some(v) = max_threads { + if v < num { + return Err(syn::Error::new_spanned( + namevalue, + "max_threads cannot be less than core_threads", + )); + } + } + + core_threads = Some(num); + } + _ => { + return Err(syn::Error::new_spanned( + namevalue, + "core_threads argument must be an int", + )) + } + } + } else { + return Err(syn::Error::new_spanned( + namevalue, + "core_threads can only be set with rt-threaded feature flag enabled", + )); + } + } + "max_threads" => match &namevalue.lit { + syn::Lit::Int(expr) => { + let num = expr.base10_parse::<NonZeroUsize>().unwrap(); + + if let Some(v) = core_threads { + if num < v { + return Err(syn::Error::new_spanned( + namevalue, + "max_threads cannot be less than core_threads", + )); + } + } + max_threads = Some(num); + } + _ => { + return Err(syn::Error::new_spanned( + namevalue, + "max_threads argument must be an int", + )) + } + }, + name => { + let msg = format!("Unknown attribute pair {} is specified; expected one of: `core_threads`, `max_threads`", name); + return Err(syn::Error::new_spanned(namevalue, msg)); + } + } + } + syn::NestedMeta::Meta(syn::Meta::Path(path)) => { + let ident = path.get_ident(); + if ident.is_none() { + let msg = "Must have specified ident"; + return Err(syn::Error::new_spanned(path, msg)); + } + match ident.unwrap().to_string().to_lowercase().as_str() { + "threaded_scheduler" => { + runtime = Some(runtime.unwrap_or_else(|| Runtime::Threaded)) + } + "basic_scheduler" => runtime = Some(runtime.unwrap_or_else(|| Runtime::Basic)), + name => { + let msg = format!("Unknown attribute {} is specified; expected `basic_scheduler` or `threaded_scheduler`", name); + return Err(syn::Error::new_spanned(path, msg)); + } + } + } + other => { + return Err(syn::Error::new_spanned( + other, + "Unknown attribute inside the macro", + )); + } + } + } + + let mut rt = quote! { tokio::runtime::Builder::new() }; + match (runtime, is_test) { + (Some(Runtime::Threaded), _) => { + if rt_threaded { + rt = quote! { #rt.threaded_scheduler() } + } + } + (Some(Runtime::Basic), _) => rt = quote! { #rt.basic_scheduler() }, + (None, true) => rt = quote! { #rt.basic_scheduler() }, + (None, false) => { + if rt_threaded { + rt = quote! { #rt.threaded_scheduler() } + } + } + }; + if let Some(v) = core_threads.map(|v| v.get()) { + rt = quote! { #rt.core_threads(#v) }; + } + if let Some(v) = max_threads.map(|v| v.get()) { + rt = quote! { #rt.max_threads(#v) }; + } + + let header = { + if is_test { + quote! { + #[test] + } + } else { + quote! {} + } + }; + + let result = quote! { + #header + #(#attrs)* + #vis fn #name(#inputs) #ret { + #rt + .enable_all() + .build() + .unwrap() + .block_on(async { #body }) + } + }; + + Ok(result.into()) } /// Marks async function to be executed by selected runtime. /// /// ## Options: /// -/// - `basic_scheduler` - All tasks are executed on the current thread. -/// - `threaded_scheduler` - Uses the multi-threaded scheduler. Used by default. +/// - `core_threads=n` - Sets core threads to `n`. +/// - `max_threads=n` - Sets max threads to `n`. /// /// ## Function arguments: /// @@ -47,95 +214,73 @@ enum Runtime { /// } /// ``` /// -/// ### Select runtime +/// ### Set number of core threads +/// +/// ```rust +/// #[tokio::main(core_threads = 1)] +/// async fn main() { +/// println!("Hello world"); +/// } +/// ``` +#[proc_macro_attribute] +#[cfg(not(test))] // Work around for rust-lang/rust#62127 +pub fn main_threaded(args: TokenStream, item: TokenStream) -> TokenStream { + main(args, item, true) +} + +/// Marks async function to be executed by selected runtime. +/// +/// ## Options: +/// +/// - `max_threads=n` - Sets max threads to `n`. +/// +/// ## Function arguments: +/// +/// Arguments are allowed for any functions aside from `main` which is special +/// +/// ## Usage +/// +/// ### Using default /// /// ```rust -/// #[tokio::main(basic_scheduler)] +/// #[tokio::main] /// async fn main() { /// println!("Hello world"); /// } /// ``` #[proc_macro_attribute] #[cfg(not(test))] // Work around for rust-lang/rust#62127 -pub fn main(args: TokenStream, item: TokenStream) -> TokenStream { +pub fn main_basic(args: TokenStream, item: TokenStream) -> TokenStream { + main(args, item, false) +} + +fn main(args: TokenStream, item: TokenStream, rt_threaded: bool) -> TokenStream { let input = syn::parse_macro_input!(item as syn::ItemFn); let args = syn::parse_macro_input!(args as syn::AttributeArgs); - let ret = &input.sig.output; - let name = &input.sig.ident; - let inputs = &input.sig.inputs; - let body = &input.block; - let attrs = &input.attrs; - let vis = input.vis; - - if input.sig.asyncness.is_none() { - let msg = "the async keyword is missing from the function declaration"; - return syn::Error::new_spanned(input.sig.fn_token, msg) - .to_compile_error() - .into(); - } else if name == "main" && !inputs.is_empty() { + if input.sig.ident == "main" && !input.sig.inputs.is_empty() { let msg = "the main function cannot accept arguments"; return syn::Error::new_spanned(&input.sig.inputs, msg) .to_compile_error() .into(); } - let mut runtime = Runtime::Auto; - - for arg in args { - if let syn::NestedMeta::Meta(syn::Meta::Path(path)) = arg { - let ident = path.get_ident(); - if ident.is_none() { - let msg = "Must have specified ident"; - return syn::Error::new_spanned(path, msg).to_compile_error().into(); - } - match ident.unwrap().to_string().to_lowercase().as_str() { - "threaded_scheduler" => runtime = Runtime::Threaded, - "basic_scheduler" => runtime = Runtime::Basic, - name => { - let msg = format!("Unknown attribute {} is specified; expected `basic_scheduler` or `threaded_scheduler`", name); - return syn::Error::new_spanned(path, msg).to_compile_error().into(); - } - } - } - } - - let result = match runtime { - Runtime::Threaded | Runtime::Auto => quote! { - #(#attrs)* - #vis fn #name(#inputs) #ret { - tokio::runtime::Runtime::new().unwrap().block_on(async { #body }) - } - }, - Runtime::Basic => quote! { - #(#attrs)* - #vis fn #name(#inputs) #ret { - tokio::runtime::Builder::new() - .basic_scheduler() - .enable_all() - .build() - .unwrap() - .block_on(async { #body }) - } - }, - }; - - result.into() + parse_knobs(input, args, false, rt_threaded).unwrap_or_else(|e| e.to_compile_error().into()) } /// Marks async function to be executed by runtime, suitable to test enviornment /// /// ## Options: /// -/// - `basic_scheduler` - All tasks are executed on the current thread. Used by default. -/// - `threaded_scheduler` - Use multi-threaded scheduler. +/// - `core_threads=n` - Sets core threads to `n`. +/// - `max_threads=n` - Sets max threads to `n`. /// /// ## Usage /// /// ### Select runtime /// /// ```no_run -/// #[tokio::test(threaded_scheduler)] +/// #[tokio::test(core_threads = 1)] /// async fn my_test() { /// assert!(true); /// } @@ -150,17 +295,34 @@ pub fn main(args: TokenStream, item: TokenStream) -> TokenStream { /// } /// ``` #[proc_macro_attribute] -pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { +pub fn test_threaded(args: TokenStream, item: TokenStream) -> TokenStream { + test(args, item, true) +} + +/// Marks async function to be executed by runtime, suitable to test enviornment +/// +/// ## Options: +/// +/// - `max_threads=n` - Sets max threads to `n`. +/// +/// ## Usage +/// +/// ```no_run +/// #[tokio::test] +/// async fn my_test() { +/// assert!(true); +/// } +/// ``` +#[proc_macro_attribute] +pub fn test_basic(args: TokenStream, item: TokenStream) -> TokenStream { + test(args, item, false) +} + +fn test(args: TokenStream, item: TokenStream, rt_threaded: bool) -> TokenStream { let input = syn::parse_macro_input!(item as syn::ItemFn); let args = syn::parse_macro_input!(args as syn::AttributeArgs); - let ret = &input.sig.output; - let name = &input.sig.ident; - let body = &input.block; - let attrs = &input.attrs; - let vis = input.vis; - - for attr in attrs { + for attr in &input.attrs { if attr.path.is_ident("test") { let msg = "second test attribute is supplied"; return syn::Error::new_spanned(&attr, msg) @@ -169,59 +331,12 @@ pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { } } - if input.sig.asyncness.is_none() { - let msg = "the async keyword is missing from the function declaration"; - return syn::Error::new_spanned(&input.sig.fn_token, msg) - .to_compile_error() - .into(); - } else if !input.sig.inputs.is_empty() { + if !input.sig.inputs.is_empty() { let msg = "the test function cannot accept arguments"; return syn::Error::new_spanned(&input.sig.inputs, msg) .to_compile_error() .into(); } - let mut runtime = Runtime::Auto; - - for arg in args { - if let syn::NestedMeta::Meta(syn::Meta::Path(path)) = arg { - let ident = path.get_ident(); - if ident.is_none() { - let msg = "Must have specified ident"; - return syn::Error::new_spanned(path, msg).to_compile_error().into(); - } - match ident.unwrap().to_string().to_lowercase().as_str() { - "threaded_scheduler" => runtime = Runtime::Threaded, - "basic_scheduler" => runtime = Runtime::Basic, - name => { - let msg = format!("Unknown attribute {} is specified; expected `basic_scheduler` or `threaded_scheduler`", name); - return syn::Error::new_spanned(path, msg).to_compile_error().into(); - } - } - } - } - - let result = match runtime { - Runtime::Threaded => quote! { - #[test] - #(#attrs)* - #vis fn #name() #ret { - tokio::runtime::Runtime::new().unwrap().block_on(async { #body }) - } - }, - Runtime::Basic | Runtime::Auto => quote! { - #[test] - #(#attrs)* - #vis fn #name() #ret { - tokio::runtime::Builder::new() - .basic_scheduler() - .enable_all() - .build() - .unwrap() - .block_on(async { #body }) - } - }, - }; - - result.into() + parse_knobs(input, args, true, rt_threaded).unwrap_or_else(|e| e.to_compile_error().into()) } |