use proc_macro::TokenStream; use proc_macro2::Span; use quote::quote; use syn::spanned::Spanned; #[derive(Clone, Copy, PartialEq)] enum RuntimeFlavor { CurrentThread, Threaded, } impl RuntimeFlavor { fn from_str(s: &str) -> Result { match s { "current_thread" => Ok(RuntimeFlavor::CurrentThread), "multi_thread" => Ok(RuntimeFlavor::Threaded), "single_thread" => Err("The single threaded runtime flavor is called `current_thread`.".to_string()), "basic_scheduler" => Err("The `basic_scheduler` runtime flavor has been renamed to `current_thread`.".to_string()), "threaded_scheduler" => Err("The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`.".to_string()), _ => Err(format!("No such runtime flavor `{}`. The runtime flavors are `current_thread` and `multi_thread`.", s)), } } } struct FinalConfig { flavor: RuntimeFlavor, worker_threads: Option, } struct Configuration { rt_multi_thread_available: bool, default_flavor: RuntimeFlavor, flavor: Option, worker_threads: Option<(usize, Span)>, } impl Configuration { fn new(is_test: bool, rt_multi_thread: bool) -> Self { Configuration { rt_multi_thread_available: rt_multi_thread, default_flavor: match is_test { true => RuntimeFlavor::CurrentThread, false => RuntimeFlavor::Threaded, }, flavor: None, worker_threads: None, } } fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> { if self.flavor.is_some() { return Err(syn::Error::new(span, "`flavor` set multiple times.")); } let runtime_str = parse_string(runtime, span, "flavor")?; let runtime = RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?; self.flavor = Some(runtime); Ok(()) } fn set_worker_threads( &mut self, worker_threads: syn::Lit, span: Span, ) -> Result<(), syn::Error> { if self.worker_threads.is_some() { return Err(syn::Error::new( span, "`worker_threads` set multiple times.", )); } let worker_threads = parse_int(worker_threads, span, "worker_threads")?; if worker_threads == 0 { return Err(syn::Error::new(span, "`worker_threads` may not be 0.")); } self.worker_threads = Some((worker_threads, span)); Ok(()) } fn build(&self) -> Result { let flavor = self.flavor.unwrap_or(self.default_flavor); use RuntimeFlavor::*; match (flavor, self.worker_threads) { (CurrentThread, Some((_, worker_threads_span))) => Err(syn::Error::new( worker_threads_span, "The `worker_threads` option requires the `multi_thread` runtime flavor.", )), (CurrentThread, None) => Ok(FinalConfig { flavor, worker_threads: None, }), (Threaded, worker_threads) if self.rt_multi_thread_available => Ok(FinalConfig { flavor, worker_threads: worker_threads.map(|(val, _span)| val), }), (Threaded, _) => { let msg = if self.flavor.is_none() { "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled." } else { "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature." }; Err(syn::Error::new(Span::call_site(), msg)) } } } } fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result { match int { syn::Lit::Int(lit) => match lit.base10_parse::() { Ok(value) => Ok(value), Err(e) => Err(syn::Error::new( span, format!("Failed to parse {} as integer: {}", field, e), )), }, _ => Err(syn::Error::new( span, format!("Failed to parse {} as integer.", field), )), } } fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result { match int { syn::Lit::Str(s) => Ok(s.value()), syn::Lit::Verbatim(s) => Ok(s.to_string()), _ => Err(syn::Error::new( span, format!("Failed to parse {} as string.", field), )), } } fn parse_knobs( mut input: syn::ItemFn, args: syn::AttributeArgs, is_test: bool, rt_multi_thread: bool, ) -> Result { let sig = &mut input.sig; let body = &input.block; let attrs = &input.attrs; let vis = input.vis; if sig.asyncness.is_none() { let msg = "the async keyword is missing from the function declaration"; return Err(syn::Error::new_spanned(sig.fn_token, msg)); } sig.asyncness = None; let macro_name = if is_test { "tokio::test" } else { "tokio::main" }; let mut config = Configuration::new(is_test, rt_multi_thread); for arg in args { match arg { syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) => { let ident = namevalue.path.get_ident(); if ident.is_none() { let msg = "Must have specified ident"; return Err(syn::Error::new_spanned(namevalue, msg)); } match ident.unwrap().to_string().to_lowercase().as_str() { "worker_threads" => { config.set_worker_threads(namevalue.lit.clone(), namevalue.span())?; } "flavor" => { config.set_flavor(namevalue.lit.clone(), namevalue.span())?; } "core_threads" => { let msg = "Attribute `core_threads` is renamed to `worker_threads`"; return Err(syn::Error::new_spanned(namevalue, msg)); } name => { let msg = format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`", name); return Err(syn::Error::new_spanned(namevalue, msg)); } } } syn::NestedMeta::Meta(syn::Meta::Path(path)) => { let ident = path.get_ident(); if ident.is_none() { let msg = "Must have specified ident"; return Err(syn::Error::new_spanned(path, msg)); } let name = ident.unwrap().to_string().to_lowercase(); let msg = match name.as_str() { "threaded_scheduler" | "multi_thread" => { format!( "Set the runtime flavor with #[{}(flavor = \"multi_thread\")].", macro_name ) } "basic_scheduler" | "current_thread" | "single_threaded" => { format!( "Set the runtime flavor with #[{}(flavor = \"current_thread\")].", macro_name ) } "flavor" | "worker_threads" => { format!("The `{}` attribute requires an argument.", name) } name => { format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`", name) } }; return Err(syn::Error::new_spanned(path, msg)); } other => { return Err(syn::Error::new_spanned( other, "Unknown attribute inside the macro", )); } } } let config = config.build()?; let mut rt = match config.flavor { RuntimeFlavor::CurrentThread => quote! { tokio::runtime::Builder::new_current_thread() }, RuntimeFlavor::Threaded => quote! { tokio::runtime::Builder::new_multi_thread() }, }; if let Some(v) = config.worker_threads { rt = quote! { #rt.worker_threads(#v) }; } let header = { if is_test { quote! { #[::core::prelude::v1::test] } } else { quote! {} } }; let result = quote! { #header #(#attrs)* #vis #sig { #rt .enable_all() .build() .unwrap() .block_on(async { #body }) } }; Ok(result.into()) } #[cfg(not(test))] // Work around for rust-lang/rust#62127 pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { let input = syn::parse_macro_input!(item as syn::ItemFn); let args = syn::parse_macro_input!(args as syn::AttributeArgs); if input.sig.ident == "main" && !input.sig.inputs.is_empty() { let msg = "the main function cannot accept arguments"; return syn::Error::new_spanned(&input.sig.ident, msg) .to_compile_error() .into(); } parse_knobs(input, args, false, rt_multi_thread).unwrap_or_else(|e| e.to_compile_error().into()) } pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { let input = syn::parse_macro_input!(item as syn::ItemFn); let args = syn::parse_macro_input!(args as syn::AttributeArgs); for attr in &input.attrs { if attr.path.is_ident("test") { let msg = "second test attribute is supplied"; return syn::Error::new_spanned(&attr, msg) .to_compile_error() .into(); } } if !input.sig.inputs.is_empty() { let msg = "the test function cannot accept arguments"; return syn::Error::new_spanned(&input.sig.inputs, msg) .to_compile_error() .into(); } parse_knobs(input, args, true, rt_multi_thread).unwrap_or_else(|e| e.to_compile_error().into()) }