aboutsummaryrefslogtreecommitdiffhomepage
path: root/cuda_base/src
diff options
context:
space:
mode:
Diffstat (limited to 'cuda_base/src')
-rw-r--r--cuda_base/src/lib.rs69
1 files changed, 39 insertions, 30 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
index 8b804d1..3f6f779 100644
--- a/cuda_base/src/lib.rs
+++ b/cuda_base/src/lib.rs
@@ -9,12 +9,11 @@ use quote::{format_ident, quote, ToTokens};
use rustc_hash::{FxHashMap, FxHashSet};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
-use syn::token::Brace;
use syn::visit_mut::VisitMut;
use syn::{
bracketed, parse_macro_input, Abi, Fields, File, FnArg, ForeignItem, ForeignItemFn, Ident,
- Item, ItemForeignMod, ItemMacro, LitStr, Macro, MacroDelimiter, PatType, Path, PathArguments,
- PathSegment, ReturnType, Signature, Token, Type, TypeArray, TypePath, TypePtr,
+ Item, ItemForeignMod, LitStr, PatType, Path, PathArguments, PathSegment, ReturnType, Signature,
+ Token, Type, TypeArray, TypePath, TypePtr,
};
const CUDA_RS: &'static str = include_str! {"cuda.rs"};
@@ -109,8 +108,11 @@ impl VisitMut for FixAbi {
// Then macro goes through every function in rust.rs, and for every fn `foo`:
// * if `foo` is contained in `override_fns` then pass it into `override_macro`
// * if `foo` is not contained in `override_fns` pass it to `normal_macro`
-// Both `override_macro` and `normal_macro` expect this format:
-// macro_foo!("system" fn cuCtxDetach(ctx: CUcontext) -> CUresult)
+// Both `override_macro` and `normal_macro` expect semicolon-separated list:
+// macro_foo!(
+// "system" fn cuCtxDetach(ctx: CUcontext) -> CUresult;
+// "system" fn cuCtxDetach(ctx: CUcontext) -> CUresult
+// )
// Additionally, it does a fixup of CUDA types so they get prefixed with `type_path`
#[proc_macro]
pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream {
@@ -121,7 +123,7 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream {
.iter()
.map(ToString::to_string)
.collect::<FxHashSet<_>>();
- cuda_module
+ let (normal_macro_args, override_macro_args): (Vec<_>, Vec<_>) = cuda_module
.items
.into_iter()
.filter_map(|item| match item {
@@ -136,12 +138,7 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream {
},
..
}) => {
- let path = if override_fns.contains(&ident.to_string()) {
- &input.override_macro
- } else {
- &input.normal_macro
- }
- .clone();
+ let use_normal_macro = !override_fns.contains(&ident.to_string());
let inputs = inputs
.into_iter()
.map(|fn_arg| match fn_arg {
@@ -158,30 +155,42 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream {
ReturnType::Default => unreachable!(),
};
let type_path = input.type_path.clone();
- let tokens = quote! {
- "system" fn #ident(#inputs) -> #type_path :: #output
- };
- Some(Item::Macro(ItemMacro {
- attrs: Vec::new(),
- ident: None,
- mac: Macro {
- path,
- bang_token: Token![!](Span::call_site()),
- delimiter: MacroDelimiter::Brace(Brace {
- span: Span::call_site(),
- }),
- tokens,
+ Some((
+ quote! {
+ "system" fn #ident(#inputs) -> #type_path :: #output
},
- semi_token: None,
- }))
+ use_normal_macro,
+ ))
}
_ => unreachable!(),
},
_ => None,
})
- .map(Item::into_token_stream)
- .collect::<proc_macro2::TokenStream>()
- .into()
+ .partition(|(_, use_normal_macro)| *use_normal_macro);
+ let mut result = proc_macro2::TokenStream::new();
+ if !normal_macro_args.is_empty() {
+ let punctuated_normal_macro_args = to_punctuated::<Token![;]>(normal_macro_args);
+ let macro_ = &input.normal_macro;
+ result.extend(iter::once(quote! {
+ #macro_ ! (#punctuated_normal_macro_args);
+ }));
+ }
+ if !override_macro_args.is_empty() {
+ let punctuated_override_macro_args = to_punctuated::<Token![;]>(override_macro_args);
+ let macro_ = &input.override_macro;
+ result.extend(iter::once(quote! {
+ #macro_ ! (#punctuated_override_macro_args);
+ }));
+ }
+ result.into()
+}
+
+fn to_punctuated<P: ToTokens + Default>(
+ elms: Vec<(proc_macro2::TokenStream, bool)>,
+) -> proc_macro2::TokenStream {
+ let mut collection = Punctuated::<proc_macro2::TokenStream, P>::new();
+ collection.extend(elms.into_iter().map(|(token_stream, _)| token_stream));
+ collection.into_token_stream()
}
fn prepend_cuda_path_to_type(base_path: &Path, type_: Box<Type>) -> Box<Type> {