diff options
Diffstat (limited to 'cuda_base/src')
-rw-r--r-- | cuda_base/src/lib.rs | 69 |
1 files changed, 39 insertions, 30 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index 8b804d1..3f6f779 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -9,12 +9,11 @@ use quote::{format_ident, quote, ToTokens}; use rustc_hash::{FxHashMap, FxHashSet}; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; -use syn::token::Brace; use syn::visit_mut::VisitMut; use syn::{ bracketed, parse_macro_input, Abi, Fields, File, FnArg, ForeignItem, ForeignItemFn, Ident, - Item, ItemForeignMod, ItemMacro, LitStr, Macro, MacroDelimiter, PatType, Path, PathArguments, - PathSegment, ReturnType, Signature, Token, Type, TypeArray, TypePath, TypePtr, + Item, ItemForeignMod, LitStr, PatType, Path, PathArguments, PathSegment, ReturnType, Signature, + Token, Type, TypeArray, TypePath, TypePtr, }; const CUDA_RS: &'static str = include_str! {"cuda.rs"}; @@ -109,8 +108,11 @@ impl VisitMut for FixAbi { // Then macro goes through every function in rust.rs, and for every fn `foo`: // * if `foo` is contained in `override_fns` then pass it into `override_macro` // * if `foo` is not contained in `override_fns` pass it to `normal_macro` -// Both `override_macro` and `normal_macro` expect this format: -// macro_foo!("system" fn cuCtxDetach(ctx: CUcontext) -> CUresult) +// Both `override_macro` and `normal_macro` expect semicolon-separated list: +// macro_foo!( +// "system" fn cuCtxDetach(ctx: CUcontext) -> CUresult; +// "system" fn cuCtxDetach(ctx: CUcontext) -> CUresult +// ) // Additionally, it does a fixup of CUDA types so they get prefixed with `type_path` #[proc_macro] pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream { @@ -121,7 +123,7 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream { .iter() .map(ToString::to_string) .collect::<FxHashSet<_>>(); - cuda_module + let (normal_macro_args, override_macro_args): (Vec<_>, Vec<_>) = cuda_module .items .into_iter() .filter_map(|item| match item { @@ -136,12 +138,7 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream { }, .. }) => { - let path = if override_fns.contains(&ident.to_string()) { - &input.override_macro - } else { - &input.normal_macro - } - .clone(); + let use_normal_macro = !override_fns.contains(&ident.to_string()); let inputs = inputs .into_iter() .map(|fn_arg| match fn_arg { @@ -158,30 +155,42 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream { ReturnType::Default => unreachable!(), }; let type_path = input.type_path.clone(); - let tokens = quote! { - "system" fn #ident(#inputs) -> #type_path :: #output - }; - Some(Item::Macro(ItemMacro { - attrs: Vec::new(), - ident: None, - mac: Macro { - path, - bang_token: Token![!](Span::call_site()), - delimiter: MacroDelimiter::Brace(Brace { - span: Span::call_site(), - }), - tokens, + Some(( + quote! { + "system" fn #ident(#inputs) -> #type_path :: #output }, - semi_token: None, - })) + use_normal_macro, + )) } _ => unreachable!(), }, _ => None, }) - .map(Item::into_token_stream) - .collect::<proc_macro2::TokenStream>() - .into() + .partition(|(_, use_normal_macro)| *use_normal_macro); + let mut result = proc_macro2::TokenStream::new(); + if !normal_macro_args.is_empty() { + let punctuated_normal_macro_args = to_punctuated::<Token![;]>(normal_macro_args); + let macro_ = &input.normal_macro; + result.extend(iter::once(quote! { + #macro_ ! (#punctuated_normal_macro_args); + })); + } + if !override_macro_args.is_empty() { + let punctuated_override_macro_args = to_punctuated::<Token![;]>(override_macro_args); + let macro_ = &input.override_macro; + result.extend(iter::once(quote! { + #macro_ ! (#punctuated_override_macro_args); + })); + } + result.into() +} + +fn to_punctuated<P: ToTokens + Default>( + elms: Vec<(proc_macro2::TokenStream, bool)>, +) -> proc_macro2::TokenStream { + let mut collection = Punctuated::<proc_macro2::TokenStream, P>::new(); + collection.extend(elms.into_iter().map(|(token_stream, _)| token_stream)); + collection.into_token_stream() } fn prepend_cuda_path_to_type(base_path: &Path, type_: Box<Type>) -> Box<Type> { |