diff options
Diffstat (limited to 'zluda_bindgen/src/main.rs')
-rw-r--r-- | zluda_bindgen/src/main.rs | 119 |
1 files changed, 96 insertions, 23 deletions
diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs index b7c7dac..3d7ea2e 100644 --- a/zluda_bindgen/src/main.rs +++ b/zluda_bindgen/src/main.rs @@ -10,33 +10,34 @@ use syn::{ fn main() { let crate_root = PathBuf::from_str(env!("CARGO_MANIFEST_DIR")).unwrap(); + generate_hip_runtime( + &crate_root, + &["..", "ext", "hip_runtime-sys", "src", "lib.rs"], + ); let cuda_header = bindgen::Builder::default() .use_core() - .header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h")) - .no_partialeq("CUDA_HOST_NODE_PARAMS_st") - .derive_eq(true) - .allowlist_type("^CU.*") - .allowlist_function("^cu.*") - .allowlist_var("^CU.*") + .rust_target(bindgen::RustTarget::Stable_1_77) + .layout_tests(false) .default_enum_style(bindgen::EnumVariation::NewType { is_bitfield: false, is_global: false, }) - .layout_tests(false) - .new_type_alias(r"^CUdevice_v\d+$") - .new_type_alias(r"^CUdeviceptr_v\d+$") + .derive_eq(true) + .header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h")) + .allowlist_type("^CU.*") + .allowlist_function("^cu.*") + .allowlist_var("^CU.*") .must_use_type("cudaError_enum") .constified_enum("cudaError_enum") + .no_partialeq("CUDA_HOST_NODE_PARAMS_st") + .new_type_alias(r"^CUdevice_v\d+$") + .new_type_alias(r"^CUdeviceptr_v\d+$") .clang_args(["-I/usr/local/cuda/include"]) .generate() .unwrap() .to_string(); let module: syn::File = syn::parse_str(&cuda_header).unwrap(); - generate_functions( - &crate_root, - &["..", "cuda_base", "src", "cuda.rs"], - &module, - ); + generate_functions(&crate_root, &["..", "cuda_base", "src", "cuda.rs"], &module); generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module); generate_display( &crate_root, @@ -46,6 +47,68 @@ fn main() { ) } +fn generate_hip_runtime(output: &PathBuf, path: &[&str]) { + let hiprt_header = bindgen::Builder::default() + .use_core() + .rust_target(bindgen::RustTarget::Stable_1_77) + .layout_tests(false) + .default_enum_style(bindgen::EnumVariation::NewType { + is_bitfield: false, + is_global: false, + }) + .derive_eq(true) + .header("/opt/rocm/include/hip/hip_runtime_api.h") + .allowlist_type("^hip.*") + .allowlist_function("^hip.*") + .allowlist_var("^hip.*") + .must_use_type("hipError_t") + .constified_enum("hipError_t") + .new_type_alias("^hipDeviceptr_t$") + .new_type_alias("^hipModule_t$") + .clang_args(["-I/opt/rocm/include", "-D__HIP_PLATFORM_AMD__"]) + .generate() + .unwrap() + .to_string(); + let mut module: syn::File = syn::parse_str(&hiprt_header).unwrap(); + let mut converter = ConvertIntoRustResult { + type_: "hipError_t", + underlying_type: "hipError_t", + new_error_type: "hipErrorCode_t", + error_prefix: ("hipError", "Error"), + success: ("hipSuccess", "Success"), + constants: Vec::new(), + }; + module.items = module + .items + .into_iter() + .filter_map(|item| match item { + Item::Const(const_) => converter.get_const(const_).map(Item::Const), + Item::Use(use_) => converter.get_use(use_).map(Item::Use), + Item::Type(type_) => converter.get_type(type_).map(Item::Type), + item => Some(item), + }) + .collect::<Vec<_>>(); + converter.flush(&mut module.items); + add_send_sync(&mut module.items, &["hipModule_t"]); + let mut output = output.clone(); + output.extend(path); + write_rust_to_file(output, &prettyplease::unparse(&module)) +} + +fn add_send_sync(items: &mut Vec<Item>, arg: &[&str]) { + for type_ in arg { + let type_ = Ident::new(type_, Span::call_site()); + items.extend([ + parse_quote! { + unsafe impl Send for #type_ {} + }, + parse_quote! { + unsafe impl Sync for #type_ {} + }, + ]); + } +} + fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) { let fns_ = module.items.iter().filter_map(|item| match item { Item::ForeignMod(extern_) => match &*extern_.items { @@ -73,7 +136,7 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) { type_: "CUresult", underlying_type: "cudaError_enum", new_error_type: "CUerror", - error_prefix: ("CUDA_ERROR", "ERROR"), + error_prefix: ("CUDA_ERROR_", "ERROR_"), success: ("CUDA_SUCCESS", "SUCCESS"), constants: Vec::new(), }; @@ -84,6 +147,7 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) { Item::ForeignMod(_) => None, Item::Const(const_) => converter.get_const(const_).map(Item::Const), Item::Use(use_) => converter.get_use(use_).map(Item::Use), + Item::Type(type_) => converter.get_type(type_).map(Item::Type), Item::Struct(mut struct_) => { let ident_string = struct_.ident.to_string(); match &*ident_string { @@ -105,6 +169,13 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) { }) .collect::<Vec<_>>(); converter.flush(&mut module.items); + module.items.push(parse_quote! { + impl From<hip_runtime_sys::hipErrorCode_t> for CUerror { + fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self { + Self(error.0) + } + } + }); syn::visit_mut::visit_file_mut(&mut FixAbi, &mut module); let mut output = output.clone(); output.extend(path); @@ -163,9 +234,9 @@ impl ConvertIntoRustResult { const #success: #type_ = #type_::Ok(()); }); } else { - let old_prefix_len = self.underlying_type.len() + 1 + self.error_prefix.0.len() + 1; + let old_prefix_len = self.underlying_type.len() + 1 + self.error_prefix.0.len(); let variant_ident = - format_ident!("{}_{}", self.error_prefix.1, &ident[old_prefix_len..]); + format_ident!("{}{}", self.error_prefix.1, &ident[old_prefix_len..]); let error_ident = format_ident!("{}", &ident[old_prefix_len..]); let expr = &const_.expr; result_variants.push(quote! { @@ -193,15 +264,17 @@ impl ConvertIntoRustResult { const _: fn() = || { let _ = std::mem::transmute::<#type_, u32>; }; - - impl From<hip_runtime_sys::hipErrorCode_t> for #new_error_type { - fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self { - Self(error.0) - } - } }; items.extend(extra_items); } + + fn get_type(&self, type_: syn::ItemType) -> Option<syn::ItemType> { + if type_.ident.to_string() == self.type_ { + None + } else { + Some(type_) + } + } } struct FixAbi; |