diff options
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r-- | cuda_base/src/lib.rs | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index 64e33ef..366edd7 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -1,6 +1,7 @@ extern crate proc_macro; use proc_macro::TokenStream; +use proc_macro2::Span; use quote::{quote, ToTokens}; use rustc_hash::FxHashMap; use std::iter; @@ -148,3 +149,49 @@ impl VisitMut for FixFnSignatures { s.inputs.pop_punct(); } } + +#[proc_macro] +pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream { + let mut path = parse_macro_input!(tokens as syn::Path); + let fn_ = path + .segments + .pop() + .unwrap() + .into_tuple() + .0 + .ident + .to_string(); + let known_modules = [ + "context", "device", "function", "link", "memory", "module", "pointer", + ]; + let segments: Vec<String> = split(&fn_[2..]); + let fn_path = join(segments, &known_modules); + quote! { + #path #fn_path + } + .into() +} + +fn split(fn_: &str) -> Vec<String> { + let mut result = Vec::new(); + for c in fn_.chars() { + if c.is_ascii_uppercase() { + result.push(c.to_ascii_lowercase().to_string()); + } else { + result.last_mut().unwrap().push(c); + } + } + result +} + +fn join(fn_: Vec<String>, known_modules: &[&str]) -> Punctuated<Ident, Token![::]> { + let (prefix, suffix) = fn_.split_at(1); + if known_modules.contains(&&*prefix[0]) { + [&prefix[0], &suffix.join("_")] + .into_iter() + .map(|seg| Ident::new(seg, Span::call_site())) + .collect() + } else { + iter::once(Ident::new(&fn_.join("_"), Span::call_site())).collect() + } +} |