aboutsummaryrefslogtreecommitdiffhomepage
path: root/cuda_base
diff options
context:
space:
mode:
Diffstat (limited to 'cuda_base')
-rw-r--r--cuda_base/README1
-rw-r--r--cuda_base/src/lib.rs17
2 files changed, 13 insertions, 5 deletions
diff --git a/cuda_base/README b/cuda_base/README
deleted file mode 100644
index 7ee6f45..0000000
--- a/cuda_base/README
+++ /dev/null
@@ -1 +0,0 @@
-bindgen build/wrapper.h -o src/cuda.rs --no-partialeq "CUDA_HOST_NODE_PARAMS_st" --with-derive-eq --allowlist-type="^CU.*" --allowlist-function="^cu.*" --allowlist-var="^CU.*" --default-enum-style=newtype --no-layout-tests --no-doc-comments --no-derive-debug --new-type-alias "^CUdevice_v\d+$|^CUdeviceptr_v\d+$" --must-use-type "cudaError_enum" --constified-enum "cudaError_enum" -- -I/usr/local/cuda/include \ No newline at end of file
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
index 366edd7..765af71 100644
--- a/cuda_base/src/lib.rs
+++ b/cuda_base/src/lib.rs
@@ -162,7 +162,13 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
.ident
.to_string();
let known_modules = [
- "context", "device", "function", "link", "memory", "module", "pointer",
+ ("ctx", "context"),
+ ("device", "device"),
+ ("function", "function"),
+ ("link", "link"),
+ ("memory", "memory"),
+ ("module", "module"),
+ ("pointer", "pointer"),
];
let segments: Vec<String> = split(&fn_[2..]);
let fn_path = join(segments, &known_modules);
@@ -184,10 +190,13 @@ fn split(fn_: &str) -> Vec<String> {
result
}
-fn join(fn_: Vec<String>, known_modules: &[&str]) -> Punctuated<Ident, Token![::]> {
+fn join(fn_: Vec<String>, known_modules: &[(&str, &str)]) -> Punctuated<Ident, Token![::]> {
let (prefix, suffix) = fn_.split_at(1);
- if known_modules.contains(&&*prefix[0]) {
- [&prefix[0], &suffix.join("_")]
+ if let Some((_, mod_name)) = known_modules
+ .iter()
+ .find(|(mod_prefix, _)| mod_prefix == &prefix[0])
+ {
+ [*mod_name, &suffix.join("_")]
.into_iter()
.map(|seg| Ident::new(seg, Span::call_site()))
.collect()