aboutsummaryrefslogtreecommitdiffhomepage
path: root/cuda_base
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2024-11-25 06:17:14 +0100
committerAndrzej Janik <[email protected]>2024-11-25 06:17:14 +0100
commit502b0c957e1fb58f5b6df678a26b8758349f8eb4 (patch)
tree9f9d22a9293c08090f54f0f65681a8cb4e86bfc8 /cuda_base
parentc461cefd7d57edd430d74780e90d25859f3b7472 (diff)
downloadZLUDA-502b0c957e1fb58f5b6df678a26b8758349f8eb4.tar.gz
ZLUDA-502b0c957e1fb58f5b6df678a26b8758349f8eb4.zip
Add more missing host-side code
Diffstat (limited to 'cuda_base')
-rw-r--r--cuda_base/src/lib.rs38
1 files changed, 19 insertions, 19 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
index 0cc1f53..833d372 100644
--- a/cuda_base/src/lib.rs
+++ b/cuda_base/src/lib.rs
@@ -150,6 +150,10 @@ impl VisitMut for FixFnSignatures {
}
}
+const MODULES: &[&str] = &[
+ "context", "device", "driver", "function", "link", "memory", "module", "pointer",
+];
+
#[proc_macro]
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
let mut path = parse_macro_input!(tokens as syn::Path);
@@ -161,8 +165,9 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
.0
.ident
.to_string();
+ let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string());
let segments: Vec<String> = split(&fn_[2..]); // skip "cu"
- let fn_path = join(segments);
+ let fn_path = join(segments, !already_has_module);
quote! {
#path #fn_path
}
@@ -181,23 +186,16 @@ fn split(fn_: &str) -> Vec<String> {
result
}
-fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> {
+fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
fn full_form(segment: &str) -> Option<&[&str]> {
Some(match segment {
"ctx" => &["context"],
+ "func" => &["function"],
+ "mem" => &["memory"],
"memcpy" => &["memory", "copy"],
_ => return None,
})
}
- const MODULES: &[&str] = &[
- "context",
- "device",
- "function",
- "link",
- "memory",
- "module",
- "pointer"
- ];
let mut normalized: Vec<&str> = Vec::new();
for segment in fn_.iter() {
match full_form(segment) {
@@ -205,18 +203,20 @@ fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> {
None => normalized.push(&*segment),
}
}
+ if !find_module {
+ return [Ident::new(&normalized.join("_"), Span::call_site())]
+ .into_iter()
+ .collect();
+ }
if !MODULES.contains(&normalized[0]) {
- let mut globalized = vec!["global"];
+ let mut globalized = vec!["driver"];
globalized.extend(normalized);
normalized = globalized;
}
let (module, path) = normalized.split_first().unwrap();
let path = path.join("_");
- let mut result = Punctuated::new();
- result.extend(
- [module, &&*path]
- .into_iter()
- .map(|s| Ident::new(s, Span::call_site())),
- );
- result
+ [module, &&*path]
+ .into_iter()
+ .map(|s| Ident::new(s, Span::call_site()))
+ .collect()
}