diff options
author | Andrzej Janik <[email protected]> | 2024-11-25 06:17:14 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2024-11-25 06:17:14 +0100 |
commit | 502b0c957e1fb58f5b6df678a26b8758349f8eb4 (patch) | |
tree | 9f9d22a9293c08090f54f0f65681a8cb4e86bfc8 /cuda_base | |
parent | c461cefd7d57edd430d74780e90d25859f3b7472 (diff) | |
download | ZLUDA-502b0c957e1fb58f5b6df678a26b8758349f8eb4.tar.gz ZLUDA-502b0c957e1fb58f5b6df678a26b8758349f8eb4.zip |
Add more missing host-side code
Diffstat (limited to 'cuda_base')
-rw-r--r-- | cuda_base/src/lib.rs | 38 |
1 files changed, 19 insertions, 19 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index 0cc1f53..833d372 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -150,6 +150,10 @@ impl VisitMut for FixFnSignatures { } } +const MODULES: &[&str] = &[ + "context", "device", "driver", "function", "link", "memory", "module", "pointer", +]; + #[proc_macro] pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream { let mut path = parse_macro_input!(tokens as syn::Path); @@ -161,8 +165,9 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream { .0 .ident .to_string(); + let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string()); let segments: Vec<String> = split(&fn_[2..]); // skip "cu" - let fn_path = join(segments); + let fn_path = join(segments, !already_has_module); quote! { #path #fn_path } @@ -181,23 +186,16 @@ fn split(fn_: &str) -> Vec<String> { result } -fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> { +fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> { fn full_form(segment: &str) -> Option<&[&str]> { Some(match segment { "ctx" => &["context"], + "func" => &["function"], + "mem" => &["memory"], "memcpy" => &["memory", "copy"], _ => return None, }) } - const MODULES: &[&str] = &[ - "context", - "device", - "function", - "link", - "memory", - "module", - "pointer" - ]; let mut normalized: Vec<&str> = Vec::new(); for segment in fn_.iter() { match full_form(segment) { @@ -205,18 +203,20 @@ fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> { None => normalized.push(&*segment), } } + if !find_module { + return [Ident::new(&normalized.join("_"), Span::call_site())] + .into_iter() + .collect(); + } if !MODULES.contains(&normalized[0]) { - let mut globalized = vec!["global"]; + let mut globalized = vec!["driver"]; globalized.extend(normalized); normalized = globalized; } let (module, path) = normalized.split_first().unwrap(); let path = path.join("_"); - let mut result = Punctuated::new(); - result.extend( - [module, &&*path] - .into_iter() - .map(|s| Ident::new(s, Span::call_site())), - ); - result + [module, &&*path] + .into_iter() + .map(|s| Ident::new(s, Span::call_site())) + .collect() } |