diff options
Diffstat (limited to 'zluda_bindgen/src/main.rs')
-rw-r--r-- | zluda_bindgen/src/main.rs | 85 |
1 files changed, 61 insertions, 24 deletions
diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs index 3d7ea2e..7332254 100644 --- a/zluda_bindgen/src/main.rs +++ b/zluda_bindgen/src/main.rs @@ -3,9 +3,9 @@ use quote::{format_ident, quote, ToTokens}; use rustc_hash::{FxHashMap, FxHashSet}; use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::FromStr}; use syn::{ - parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FnArg, ForeignItem, - ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path, PathArguments, - Signature, Type, TypePath, UseTree, + parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FieldsUnnamed, FnArg, + ForeignItem, ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path, + PathArguments, Signature, Type, TypePath, UseTree, }; fn main() { @@ -22,6 +22,7 @@ fn main() { is_bitfield: false, is_global: false, }) + .derive_hash(true) .derive_eq(true) .header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h")) .allowlist_type("^CU.*") @@ -30,8 +31,12 @@ fn main() { .must_use_type("cudaError_enum") .constified_enum("cudaError_enum") .no_partialeq("CUDA_HOST_NODE_PARAMS_st") - .new_type_alias(r"^CUdevice_v\d+$") .new_type_alias(r"^CUdeviceptr_v\d+$") + .new_type_alias(r"^CUcontext$") + .new_type_alias(r"^CUstream$") + .new_type_alias(r"^CUmodule$") + .new_type_alias(r"^CUfunction$") + .new_type_alias(r"^CUlibrary$") .clang_args(["-I/usr/local/cuda/include"]) .generate() .unwrap() @@ -56,6 +61,7 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) { is_bitfield: false, is_global: false, }) + .derive_hash(true) .derive_eq(true) .header("/opt/rocm/include/hip/hip_runtime_api.h") .allowlist_type("^hip.*") @@ -64,7 +70,9 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) { .must_use_type("hipError_t") .constified_enum("hipError_t") .new_type_alias("^hipDeviceptr_t$") + .new_type_alias("^hipStream_t$") .new_type_alias("^hipModule_t$") + .new_type_alias("^hipFunction_t$") .clang_args(["-I/opt/rocm/include", "-D__HIP_PLATFORM_AMD__"]) .generate() .unwrap() @@ -89,7 +97,15 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) { }) .collect::<Vec<_>>(); converter.flush(&mut module.items); - add_send_sync(&mut module.items, &["hipModule_t"]); + add_send_sync( + &mut module.items, + &[ + "hipDeviceptr_t", + "hipStream_t", + "hipModule_t", + "hipFunction_t", + ], + ); let mut output = output.clone(); output.extend(path); write_rust_to_file(output, &prettyplease::unparse(&module)) @@ -176,6 +192,17 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) { } } }); + add_send_sync( + &mut module.items, + &[ + "CUdeviceptr", + "CUcontext", + "CUstream", + "CUmodule", + "CUfunction", + "CUlibrary", + ], + ); syn::visit_mut::visit_file_mut(&mut FixAbi, &mut module); let mut output = output.clone(); output.extend(path); @@ -252,7 +279,7 @@ impl ConvertIntoRustResult { #(#error_variants)* } #[repr(transparent)] - #[derive(Debug, Copy, Clone, PartialEq, Eq)] + #[derive(Debug, Hash, Copy, Clone, PartialEq, Eq)] pub struct #new_error_type(pub ::core::num::NonZeroU32); pub trait #type_trait { @@ -327,6 +354,8 @@ fn generate_display( module: &syn::File, ) { let ignore_types = [ + "CUdevice", + "CUdeviceptr_v1", "CUarrayMapInfo_st", "CUDA_RESOURCE_DESC_st", "CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st", @@ -545,9 +574,9 @@ fn cuda_derive_display_trait_for_item<'a>( }) } else { let struct_ = &item_struct.ident; - let (first_field, rest_of_fields) = match item_struct.fields { + match item_struct.fields { Fields::Named(ref fields) => { - let mut all_idents = fields.named.iter().filter_map(|f| { + let mut rest_of_fields = fields.named.iter().filter_map(|f| { let f_ident = f.ident.as_ref().unwrap(); let name = f_ident.to_string(); if name.starts_with("reserved") || name == "_unused" { @@ -556,27 +585,35 @@ fn cuda_derive_display_trait_for_item<'a>( Some(f_ident) } }); - let first = match all_idents.next() { + let first_field = match rest_of_fields.next() { Some(f) => f, None => return None, }; - (first, all_idents) + Some(parse_quote! { + impl crate::format::CudaDisplay for #path_prefix :: #struct_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + writer.write_all(concat!("{ ", stringify!(#first_field), ": ").as_bytes())?; + crate::format::CudaDisplay::write(&self.#first_field, "", 0, writer)?; + #( + writer.write_all(concat!(", ", stringify!(#rest_of_fields), ": ").as_bytes())?; + crate::format::CudaDisplay::write(&self.#rest_of_fields, "", 0, writer)?; + )* + writer.write_all(b" }") + } + } + }) } - _ => return None, - }; - Some(parse_quote! { - impl crate::format::CudaDisplay for #path_prefix :: #struct_ { - fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { - writer.write_all(concat!("{ ", stringify!(#first_field), ": ").as_bytes())?; - crate::format::CudaDisplay::write(&self.#first_field, "", 0, writer)?; - #( - writer.write_all(concat!(", ", stringify!(#rest_of_fields), ": ").as_bytes())?; - crate::format::CudaDisplay::write(&self.#rest_of_fields, "", 0, writer)?; - )* - writer.write_all(b" }") - } + Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }) if unnamed.len() == 1 => { + Some(parse_quote! { + impl crate::format::CudaDisplay for #path_prefix :: #struct_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + write!(writer, "{:p}", self.0) + } + } + }) } - }) + _ => return None, + } } } Item::Type(item_type) => { |