aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_bindgen/src/main.rs
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2024-11-25 04:08:31 +0000
committerAndrzej Janik <[email protected]>2024-11-25 04:08:31 +0000
commitc461cefd7d57edd430d74780e90d25859f3b7472 (patch)
tree5b2fb1214d1de6bdb029e6d1cbf488016a44d967 /zluda_bindgen/src/main.rs
parent9f677e23c022955d552f2d530488ef51a95f0d6c (diff)
downloadZLUDA-c461cefd7d57edd430d74780e90d25859f3b7472.tar.gz
ZLUDA-c461cefd7d57edd430d74780e90d25859f3b7472.zip
Rebindgen to emit send,sync,hash
Diffstat (limited to 'zluda_bindgen/src/main.rs')
-rw-r--r--zluda_bindgen/src/main.rs85
1 files changed, 61 insertions, 24 deletions
diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs
index 3d7ea2e..7332254 100644
--- a/zluda_bindgen/src/main.rs
+++ b/zluda_bindgen/src/main.rs
@@ -3,9 +3,9 @@ use quote::{format_ident, quote, ToTokens};
use rustc_hash::{FxHashMap, FxHashSet};
use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::FromStr};
use syn::{
- parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FnArg, ForeignItem,
- ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path, PathArguments,
- Signature, Type, TypePath, UseTree,
+ parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FieldsUnnamed, FnArg,
+ ForeignItem, ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path,
+ PathArguments, Signature, Type, TypePath, UseTree,
};
fn main() {
@@ -22,6 +22,7 @@ fn main() {
is_bitfield: false,
is_global: false,
})
+ .derive_hash(true)
.derive_eq(true)
.header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h"))
.allowlist_type("^CU.*")
@@ -30,8 +31,12 @@ fn main() {
.must_use_type("cudaError_enum")
.constified_enum("cudaError_enum")
.no_partialeq("CUDA_HOST_NODE_PARAMS_st")
- .new_type_alias(r"^CUdevice_v\d+$")
.new_type_alias(r"^CUdeviceptr_v\d+$")
+ .new_type_alias(r"^CUcontext$")
+ .new_type_alias(r"^CUstream$")
+ .new_type_alias(r"^CUmodule$")
+ .new_type_alias(r"^CUfunction$")
+ .new_type_alias(r"^CUlibrary$")
.clang_args(["-I/usr/local/cuda/include"])
.generate()
.unwrap()
@@ -56,6 +61,7 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
is_bitfield: false,
is_global: false,
})
+ .derive_hash(true)
.derive_eq(true)
.header("/opt/rocm/include/hip/hip_runtime_api.h")
.allowlist_type("^hip.*")
@@ -64,7 +70,9 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
.must_use_type("hipError_t")
.constified_enum("hipError_t")
.new_type_alias("^hipDeviceptr_t$")
+ .new_type_alias("^hipStream_t$")
.new_type_alias("^hipModule_t$")
+ .new_type_alias("^hipFunction_t$")
.clang_args(["-I/opt/rocm/include", "-D__HIP_PLATFORM_AMD__"])
.generate()
.unwrap()
@@ -89,7 +97,15 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
})
.collect::<Vec<_>>();
converter.flush(&mut module.items);
- add_send_sync(&mut module.items, &["hipModule_t"]);
+ add_send_sync(
+ &mut module.items,
+ &[
+ "hipDeviceptr_t",
+ "hipStream_t",
+ "hipModule_t",
+ "hipFunction_t",
+ ],
+ );
let mut output = output.clone();
output.extend(path);
write_rust_to_file(output, &prettyplease::unparse(&module))
@@ -176,6 +192,17 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
}
}
});
+ add_send_sync(
+ &mut module.items,
+ &[
+ "CUdeviceptr",
+ "CUcontext",
+ "CUstream",
+ "CUmodule",
+ "CUfunction",
+ "CUlibrary",
+ ],
+ );
syn::visit_mut::visit_file_mut(&mut FixAbi, &mut module);
let mut output = output.clone();
output.extend(path);
@@ -252,7 +279,7 @@ impl ConvertIntoRustResult {
#(#error_variants)*
}
#[repr(transparent)]
- #[derive(Debug, Copy, Clone, PartialEq, Eq)]
+ #[derive(Debug, Hash, Copy, Clone, PartialEq, Eq)]
pub struct #new_error_type(pub ::core::num::NonZeroU32);
pub trait #type_trait {
@@ -327,6 +354,8 @@ fn generate_display(
module: &syn::File,
) {
let ignore_types = [
+ "CUdevice",
+ "CUdeviceptr_v1",
"CUarrayMapInfo_st",
"CUDA_RESOURCE_DESC_st",
"CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st",
@@ -545,9 +574,9 @@ fn cuda_derive_display_trait_for_item<'a>(
})
} else {
let struct_ = &item_struct.ident;
- let (first_field, rest_of_fields) = match item_struct.fields {
+ match item_struct.fields {
Fields::Named(ref fields) => {
- let mut all_idents = fields.named.iter().filter_map(|f| {
+ let mut rest_of_fields = fields.named.iter().filter_map(|f| {
let f_ident = f.ident.as_ref().unwrap();
let name = f_ident.to_string();
if name.starts_with("reserved") || name == "_unused" {
@@ -556,27 +585,35 @@ fn cuda_derive_display_trait_for_item<'a>(
Some(f_ident)
}
});
- let first = match all_idents.next() {
+ let first_field = match rest_of_fields.next() {
Some(f) => f,
None => return None,
};
- (first, all_idents)
+ Some(parse_quote! {
+ impl crate::format::CudaDisplay for #path_prefix :: #struct_ {
+ fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ writer.write_all(concat!("{ ", stringify!(#first_field), ": ").as_bytes())?;
+ crate::format::CudaDisplay::write(&self.#first_field, "", 0, writer)?;
+ #(
+ writer.write_all(concat!(", ", stringify!(#rest_of_fields), ": ").as_bytes())?;
+ crate::format::CudaDisplay::write(&self.#rest_of_fields, "", 0, writer)?;
+ )*
+ writer.write_all(b" }")
+ }
+ }
+ })
}
- _ => return None,
- };
- Some(parse_quote! {
- impl crate::format::CudaDisplay for #path_prefix :: #struct_ {
- fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
- writer.write_all(concat!("{ ", stringify!(#first_field), ": ").as_bytes())?;
- crate::format::CudaDisplay::write(&self.#first_field, "", 0, writer)?;
- #(
- writer.write_all(concat!(", ", stringify!(#rest_of_fields), ": ").as_bytes())?;
- crate::format::CudaDisplay::write(&self.#rest_of_fields, "", 0, writer)?;
- )*
- writer.write_all(b" }")
- }
+ Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }) if unnamed.len() == 1 => {
+ Some(parse_quote! {
+ impl crate::format::CudaDisplay for #path_prefix :: #struct_ {
+ fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ write!(writer, "{:p}", self.0)
+ }
+ }
+ })
}
- })
+ _ => return None,
+ }
}
}
Item::Type(item_type) => {