aboutsummaryrefslogtreecommitdiffhomepage
path: root/cuda_base/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r--cuda_base/src/lib.rs62
1 files changed, 39 insertions, 23 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
index 765af71..0cc1f53 100644
--- a/cuda_base/src/lib.rs
+++ b/cuda_base/src/lib.rs
@@ -161,17 +161,8 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
.0
.ident
.to_string();
- let known_modules = [
- ("ctx", "context"),
- ("device", "device"),
- ("function", "function"),
- ("link", "link"),
- ("memory", "memory"),
- ("module", "module"),
- ("pointer", "pointer"),
- ];
- let segments: Vec<String> = split(&fn_[2..]);
- let fn_path = join(segments, &known_modules);
+ let segments: Vec<String> = split(&fn_[2..]); // skip "cu"
+ let fn_path = join(segments);
quote! {
#path #fn_path
}
@@ -190,17 +181,42 @@ fn split(fn_: &str) -> Vec<String> {
result
}
-fn join(fn_: Vec<String>, known_modules: &[(&str, &str)]) -> Punctuated<Ident, Token![::]> {
- let (prefix, suffix) = fn_.split_at(1);
- if let Some((_, mod_name)) = known_modules
- .iter()
- .find(|(mod_prefix, _)| mod_prefix == &prefix[0])
- {
- [*mod_name, &suffix.join("_")]
- .into_iter()
- .map(|seg| Ident::new(seg, Span::call_site()))
- .collect()
- } else {
- iter::once(Ident::new(&fn_.join("_"), Span::call_site())).collect()
+fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> {
+ fn full_form(segment: &str) -> Option<&[&str]> {
+ Some(match segment {
+ "ctx" => &["context"],
+ "memcpy" => &["memory", "copy"],
+ _ => return None,
+ })
}
+ const MODULES: &[&str] = &[
+ "context",
+ "device",
+ "function",
+ "link",
+ "memory",
+ "module",
+ "pointer"
+ ];
+ let mut normalized: Vec<&str> = Vec::new();
+ for segment in fn_.iter() {
+ match full_form(segment) {
+ Some(segments) => normalized.extend(segments.into_iter()),
+ None => normalized.push(&*segment),
+ }
+ }
+ if !MODULES.contains(&normalized[0]) {
+ let mut globalized = vec!["global"];
+ globalized.extend(normalized);
+ normalized = globalized;
+ }
+ let (module, path) = normalized.split_first().unwrap();
+ let path = path.join("_");
+ let mut result = Punctuated::new();
+ result.extend(
+ [module, &&*path]
+ .into_iter()
+ .map(|s| Ident::new(s, Span::call_site())),
+ );
+ result
}