diff options
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r-- | cuda_base/src/lib.rs | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index 366edd7..765af71 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -162,7 +162,13 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream { .ident .to_string(); let known_modules = [ - "context", "device", "function", "link", "memory", "module", "pointer", + ("ctx", "context"), + ("device", "device"), + ("function", "function"), + ("link", "link"), + ("memory", "memory"), + ("module", "module"), + ("pointer", "pointer"), ]; let segments: Vec<String> = split(&fn_[2..]); let fn_path = join(segments, &known_modules); @@ -184,10 +190,13 @@ fn split(fn_: &str) -> Vec<String> { result } -fn join(fn_: Vec<String>, known_modules: &[&str]) -> Punctuated<Ident, Token![::]> { +fn join(fn_: Vec<String>, known_modules: &[(&str, &str)]) -> Punctuated<Ident, Token![::]> { let (prefix, suffix) = fn_.split_at(1); - if known_modules.contains(&&*prefix[0]) { - [&prefix[0], &suffix.join("_")] + if let Some((_, mod_name)) = known_modules + .iter() + .find(|(mod_prefix, _)| mod_prefix == &prefix[0]) + { + [*mod_name, &suffix.join("_")] .into_iter() .map(|seg| Ident::new(seg, Span::call_site())) .collect() |