aboutsummaryrefslogtreecommitdiffhomepage
path: root/cuda_base/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r--cuda_base/src/lib.rs17
1 files changed, 13 insertions, 4 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
index 366edd7..765af71 100644
--- a/cuda_base/src/lib.rs
+++ b/cuda_base/src/lib.rs
@@ -162,7 +162,13 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
.ident
.to_string();
let known_modules = [
- "context", "device", "function", "link", "memory", "module", "pointer",
+ ("ctx", "context"),
+ ("device", "device"),
+ ("function", "function"),
+ ("link", "link"),
+ ("memory", "memory"),
+ ("module", "module"),
+ ("pointer", "pointer"),
];
let segments: Vec<String> = split(&fn_[2..]);
let fn_path = join(segments, &known_modules);
@@ -184,10 +190,13 @@ fn split(fn_: &str) -> Vec<String> {
result
}
-fn join(fn_: Vec<String>, known_modules: &[&str]) -> Punctuated<Ident, Token![::]> {
+fn join(fn_: Vec<String>, known_modules: &[(&str, &str)]) -> Punctuated<Ident, Token![::]> {
let (prefix, suffix) = fn_.split_at(1);
- if known_modules.contains(&&*prefix[0]) {
- [&prefix[0], &suffix.join("_")]
+ if let Some((_, mod_name)) = known_modules
+ .iter()
+ .find(|(mod_prefix, _)| mod_prefix == &prefix[0])
+ {
+ [*mod_name, &suffix.join("_")]
.into_iter()
.map(|seg| Ident::new(seg, Span::call_site()))
.collect()