aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_bindgen/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda_bindgen/src/main.rs')
-rw-r--r--zluda_bindgen/src/main.rs119
1 files changed, 96 insertions, 23 deletions
diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs
index b7c7dac..3d7ea2e 100644
--- a/zluda_bindgen/src/main.rs
+++ b/zluda_bindgen/src/main.rs
@@ -10,33 +10,34 @@ use syn::{
fn main() {
let crate_root = PathBuf::from_str(env!("CARGO_MANIFEST_DIR")).unwrap();
+ generate_hip_runtime(
+ &crate_root,
+ &["..", "ext", "hip_runtime-sys", "src", "lib.rs"],
+ );
let cuda_header = bindgen::Builder::default()
.use_core()
- .header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h"))
- .no_partialeq("CUDA_HOST_NODE_PARAMS_st")
- .derive_eq(true)
- .allowlist_type("^CU.*")
- .allowlist_function("^cu.*")
- .allowlist_var("^CU.*")
+ .rust_target(bindgen::RustTarget::Stable_1_77)
+ .layout_tests(false)
.default_enum_style(bindgen::EnumVariation::NewType {
is_bitfield: false,
is_global: false,
})
- .layout_tests(false)
- .new_type_alias(r"^CUdevice_v\d+$")
- .new_type_alias(r"^CUdeviceptr_v\d+$")
+ .derive_eq(true)
+ .header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h"))
+ .allowlist_type("^CU.*")
+ .allowlist_function("^cu.*")
+ .allowlist_var("^CU.*")
.must_use_type("cudaError_enum")
.constified_enum("cudaError_enum")
+ .no_partialeq("CUDA_HOST_NODE_PARAMS_st")
+ .new_type_alias(r"^CUdevice_v\d+$")
+ .new_type_alias(r"^CUdeviceptr_v\d+$")
.clang_args(["-I/usr/local/cuda/include"])
.generate()
.unwrap()
.to_string();
let module: syn::File = syn::parse_str(&cuda_header).unwrap();
- generate_functions(
- &crate_root,
- &["..", "cuda_base", "src", "cuda.rs"],
- &module,
- );
+ generate_functions(&crate_root, &["..", "cuda_base", "src", "cuda.rs"], &module);
generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module);
generate_display(
&crate_root,
@@ -46,6 +47,68 @@ fn main() {
)
}
+fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
+ let hiprt_header = bindgen::Builder::default()
+ .use_core()
+ .rust_target(bindgen::RustTarget::Stable_1_77)
+ .layout_tests(false)
+ .default_enum_style(bindgen::EnumVariation::NewType {
+ is_bitfield: false,
+ is_global: false,
+ })
+ .derive_eq(true)
+ .header("/opt/rocm/include/hip/hip_runtime_api.h")
+ .allowlist_type("^hip.*")
+ .allowlist_function("^hip.*")
+ .allowlist_var("^hip.*")
+ .must_use_type("hipError_t")
+ .constified_enum("hipError_t")
+ .new_type_alias("^hipDeviceptr_t$")
+ .new_type_alias("^hipModule_t$")
+ .clang_args(["-I/opt/rocm/include", "-D__HIP_PLATFORM_AMD__"])
+ .generate()
+ .unwrap()
+ .to_string();
+ let mut module: syn::File = syn::parse_str(&hiprt_header).unwrap();
+ let mut converter = ConvertIntoRustResult {
+ type_: "hipError_t",
+ underlying_type: "hipError_t",
+ new_error_type: "hipErrorCode_t",
+ error_prefix: ("hipError", "Error"),
+ success: ("hipSuccess", "Success"),
+ constants: Vec::new(),
+ };
+ module.items = module
+ .items
+ .into_iter()
+ .filter_map(|item| match item {
+ Item::Const(const_) => converter.get_const(const_).map(Item::Const),
+ Item::Use(use_) => converter.get_use(use_).map(Item::Use),
+ Item::Type(type_) => converter.get_type(type_).map(Item::Type),
+ item => Some(item),
+ })
+ .collect::<Vec<_>>();
+ converter.flush(&mut module.items);
+ add_send_sync(&mut module.items, &["hipModule_t"]);
+ let mut output = output.clone();
+ output.extend(path);
+ write_rust_to_file(output, &prettyplease::unparse(&module))
+}
+
+fn add_send_sync(items: &mut Vec<Item>, arg: &[&str]) {
+ for type_ in arg {
+ let type_ = Ident::new(type_, Span::call_site());
+ items.extend([
+ parse_quote! {
+ unsafe impl Send for #type_ {}
+ },
+ parse_quote! {
+ unsafe impl Sync for #type_ {}
+ },
+ ]);
+ }
+}
+
fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) {
let fns_ = module.items.iter().filter_map(|item| match item {
Item::ForeignMod(extern_) => match &*extern_.items {
@@ -73,7 +136,7 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
type_: "CUresult",
underlying_type: "cudaError_enum",
new_error_type: "CUerror",
- error_prefix: ("CUDA_ERROR", "ERROR"),
+ error_prefix: ("CUDA_ERROR_", "ERROR_"),
success: ("CUDA_SUCCESS", "SUCCESS"),
constants: Vec::new(),
};
@@ -84,6 +147,7 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
Item::ForeignMod(_) => None,
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
+ Item::Type(type_) => converter.get_type(type_).map(Item::Type),
Item::Struct(mut struct_) => {
let ident_string = struct_.ident.to_string();
match &*ident_string {
@@ -105,6 +169,13 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
})
.collect::<Vec<_>>();
converter.flush(&mut module.items);
+ module.items.push(parse_quote! {
+ impl From<hip_runtime_sys::hipErrorCode_t> for CUerror {
+ fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self {
+ Self(error.0)
+ }
+ }
+ });
syn::visit_mut::visit_file_mut(&mut FixAbi, &mut module);
let mut output = output.clone();
output.extend(path);
@@ -163,9 +234,9 @@ impl ConvertIntoRustResult {
const #success: #type_ = #type_::Ok(());
});
} else {
- let old_prefix_len = self.underlying_type.len() + 1 + self.error_prefix.0.len() + 1;
+ let old_prefix_len = self.underlying_type.len() + 1 + self.error_prefix.0.len();
let variant_ident =
- format_ident!("{}_{}", self.error_prefix.1, &ident[old_prefix_len..]);
+ format_ident!("{}{}", self.error_prefix.1, &ident[old_prefix_len..]);
let error_ident = format_ident!("{}", &ident[old_prefix_len..]);
let expr = &const_.expr;
result_variants.push(quote! {
@@ -193,15 +264,17 @@ impl ConvertIntoRustResult {
const _: fn() = || {
let _ = std::mem::transmute::<#type_, u32>;
};
-
- impl From<hip_runtime_sys::hipErrorCode_t> for #new_error_type {
- fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self {
- Self(error.0)
- }
- }
};
items.extend(extra_items);
}
+
+ fn get_type(&self, type_: syn::ItemType) -> Option<syn::ItemType> {
+ if type_.ident.to_string() == self.type_ {
+ None
+ } else {
+ Some(type_)
+ }
+ }
}
struct FixAbi;