diff options
Diffstat (limited to 'cuda_base')
-rw-r--r-- | cuda_base/README | 1 | ||||
-rw-r--r-- | cuda_base/src/lib.rs | 17 |
2 files changed, 13 insertions, 5 deletions
diff --git a/cuda_base/README b/cuda_base/README deleted file mode 100644 index 7ee6f45..0000000 --- a/cuda_base/README +++ /dev/null @@ -1 +0,0 @@ -bindgen build/wrapper.h -o src/cuda.rs --no-partialeq "CUDA_HOST_NODE_PARAMS_st" --with-derive-eq --allowlist-type="^CU.*" --allowlist-function="^cu.*" --allowlist-var="^CU.*" --default-enum-style=newtype --no-layout-tests --no-doc-comments --no-derive-debug --new-type-alias "^CUdevice_v\d+$|^CUdeviceptr_v\d+$" --must-use-type "cudaError_enum" --constified-enum "cudaError_enum" -- -I/usr/local/cuda/include
\ No newline at end of file diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index 366edd7..765af71 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -162,7 +162,13 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream { .ident .to_string(); let known_modules = [ - "context", "device", "function", "link", "memory", "module", "pointer", + ("ctx", "context"), + ("device", "device"), + ("function", "function"), + ("link", "link"), + ("memory", "memory"), + ("module", "module"), + ("pointer", "pointer"), ]; let segments: Vec<String> = split(&fn_[2..]); let fn_path = join(segments, &known_modules); @@ -184,10 +190,13 @@ fn split(fn_: &str) -> Vec<String> { result } -fn join(fn_: Vec<String>, known_modules: &[&str]) -> Punctuated<Ident, Token![::]> { +fn join(fn_: Vec<String>, known_modules: &[(&str, &str)]) -> Punctuated<Ident, Token![::]> { let (prefix, suffix) = fn_.split_at(1); - if known_modules.contains(&&*prefix[0]) { - [&prefix[0], &suffix.join("_")] + if let Some((_, mod_name)) = known_modules + .iter() + .find(|(mod_prefix, _)| mod_prefix == &prefix[0]) + { + [*mod_name, &suffix.join("_")] .into_iter() .map(|seg| Ident::new(seg, Span::call_site())) .collect() |