diff options
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r-- | cuda_base/src/lib.rs | 62 |
1 files changed, 39 insertions, 23 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index 765af71..0cc1f53 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -161,17 +161,8 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream { .0 .ident .to_string(); - let known_modules = [ - ("ctx", "context"), - ("device", "device"), - ("function", "function"), - ("link", "link"), - ("memory", "memory"), - ("module", "module"), - ("pointer", "pointer"), - ]; - let segments: Vec<String> = split(&fn_[2..]); - let fn_path = join(segments, &known_modules); + let segments: Vec<String> = split(&fn_[2..]); // skip "cu" + let fn_path = join(segments); quote! { #path #fn_path } @@ -190,17 +181,42 @@ fn split(fn_: &str) -> Vec<String> { result } -fn join(fn_: Vec<String>, known_modules: &[(&str, &str)]) -> Punctuated<Ident, Token![::]> { - let (prefix, suffix) = fn_.split_at(1); - if let Some((_, mod_name)) = known_modules - .iter() - .find(|(mod_prefix, _)| mod_prefix == &prefix[0]) - { - [*mod_name, &suffix.join("_")] - .into_iter() - .map(|seg| Ident::new(seg, Span::call_site())) - .collect() - } else { - iter::once(Ident::new(&fn_.join("_"), Span::call_site())).collect() +fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> { + fn full_form(segment: &str) -> Option<&[&str]> { + Some(match segment { + "ctx" => &["context"], + "memcpy" => &["memory", "copy"], + _ => return None, + }) } + const MODULES: &[&str] = &[ + "context", + "device", + "function", + "link", + "memory", + "module", + "pointer" + ]; + let mut normalized: Vec<&str> = Vec::new(); + for segment in fn_.iter() { + match full_form(segment) { + Some(segments) => normalized.extend(segments.into_iter()), + None => normalized.push(&*segment), + } + } + if !MODULES.contains(&normalized[0]) { + let mut globalized = vec!["global"]; + globalized.extend(normalized); + normalized = globalized; + } + let (module, path) = normalized.split_first().unwrap(); + let path = path.join("_"); + let mut result = Punctuated::new(); + result.extend( + [module, &&*path] + .into_iter() + .map(|s| Ident::new(s, Span::call_site())), + ); + result } |