aboutsummaryrefslogtreecommitdiffhomepage
path: root/cuda_base/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r--cuda_base/src/lib.rs47
1 files changed, 47 insertions, 0 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
index 64e33ef..366edd7 100644
--- a/cuda_base/src/lib.rs
+++ b/cuda_base/src/lib.rs
@@ -1,6 +1,7 @@
extern crate proc_macro;
use proc_macro::TokenStream;
+use proc_macro2::Span;
use quote::{quote, ToTokens};
use rustc_hash::FxHashMap;
use std::iter;
@@ -148,3 +149,49 @@ impl VisitMut for FixFnSignatures {
s.inputs.pop_punct();
}
}
+
+#[proc_macro]
+pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
+ let mut path = parse_macro_input!(tokens as syn::Path);
+ let fn_ = path
+ .segments
+ .pop()
+ .unwrap()
+ .into_tuple()
+ .0
+ .ident
+ .to_string();
+ let known_modules = [
+ "context", "device", "function", "link", "memory", "module", "pointer",
+ ];
+ let segments: Vec<String> = split(&fn_[2..]);
+ let fn_path = join(segments, &known_modules);
+ quote! {
+ #path #fn_path
+ }
+ .into()
+}
+
+fn split(fn_: &str) -> Vec<String> {
+ let mut result = Vec::new();
+ for c in fn_.chars() {
+ if c.is_ascii_uppercase() {
+ result.push(c.to_ascii_lowercase().to_string());
+ } else {
+ result.last_mut().unwrap().push(c);
+ }
+ }
+ result
+}
+
+fn join(fn_: Vec<String>, known_modules: &[&str]) -> Punctuated<Ident, Token![::]> {
+ let (prefix, suffix) = fn_.split_at(1);
+ if known_modules.contains(&&*prefix[0]) {
+ [&prefix[0], &suffix.join("_")]
+ .into_iter()
+ .map(|seg| Ident::new(seg, Span::call_site()))
+ .collect()
+ } else {
+ iter::once(Ident::new(&fn_.join("_"), Span::call_site())).collect()
+ }
+}