diff options
Diffstat (limited to 'zluda_bindgen/src')
-rw-r--r-- | zluda_bindgen/src/main.rs | 703 |
1 files changed, 703 insertions, 0 deletions
diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs new file mode 100644 index 0000000..7332254 --- /dev/null +++ b/zluda_bindgen/src/main.rs @@ -0,0 +1,703 @@ +use proc_macro2::Span; +use quote::{format_ident, quote, ToTokens}; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::FromStr}; +use syn::{ + parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FieldsUnnamed, FnArg, + ForeignItem, ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path, + PathArguments, Signature, Type, TypePath, UseTree, +}; + +fn main() { + let crate_root = PathBuf::from_str(env!("CARGO_MANIFEST_DIR")).unwrap(); + generate_hip_runtime( + &crate_root, + &["..", "ext", "hip_runtime-sys", "src", "lib.rs"], + ); + let cuda_header = bindgen::Builder::default() + .use_core() + .rust_target(bindgen::RustTarget::Stable_1_77) + .layout_tests(false) + .default_enum_style(bindgen::EnumVariation::NewType { + is_bitfield: false, + is_global: false, + }) + .derive_hash(true) + .derive_eq(true) + .header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h")) + .allowlist_type("^CU.*") + .allowlist_function("^cu.*") + .allowlist_var("^CU.*") + .must_use_type("cudaError_enum") + .constified_enum("cudaError_enum") + .no_partialeq("CUDA_HOST_NODE_PARAMS_st") + .new_type_alias(r"^CUdeviceptr_v\d+$") + .new_type_alias(r"^CUcontext$") + .new_type_alias(r"^CUstream$") + .new_type_alias(r"^CUmodule$") + .new_type_alias(r"^CUfunction$") + .new_type_alias(r"^CUlibrary$") + .clang_args(["-I/usr/local/cuda/include"]) + .generate() + .unwrap() + .to_string(); + let module: syn::File = syn::parse_str(&cuda_header).unwrap(); + generate_functions(&crate_root, &["..", "cuda_base", "src", "cuda.rs"], &module); + generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module); + generate_display( + &crate_root, + &["..", "zluda_dump", "src", "format_generated.rs"], + "cuda_types", + &module, + ) +} + +fn generate_hip_runtime(output: &PathBuf, path: &[&str]) { + let hiprt_header = bindgen::Builder::default() + .use_core() + .rust_target(bindgen::RustTarget::Stable_1_77) + .layout_tests(false) + .default_enum_style(bindgen::EnumVariation::NewType { + is_bitfield: false, + is_global: false, + }) + .derive_hash(true) + .derive_eq(true) + .header("/opt/rocm/include/hip/hip_runtime_api.h") + .allowlist_type("^hip.*") + .allowlist_function("^hip.*") + .allowlist_var("^hip.*") + .must_use_type("hipError_t") + .constified_enum("hipError_t") + .new_type_alias("^hipDeviceptr_t$") + .new_type_alias("^hipStream_t$") + .new_type_alias("^hipModule_t$") + .new_type_alias("^hipFunction_t$") + .clang_args(["-I/opt/rocm/include", "-D__HIP_PLATFORM_AMD__"]) + .generate() + .unwrap() + .to_string(); + let mut module: syn::File = syn::parse_str(&hiprt_header).unwrap(); + let mut converter = ConvertIntoRustResult { + type_: "hipError_t", + underlying_type: "hipError_t", + new_error_type: "hipErrorCode_t", + error_prefix: ("hipError", "Error"), + success: ("hipSuccess", "Success"), + constants: Vec::new(), + }; + module.items = module + .items + .into_iter() + .filter_map(|item| match item { + Item::Const(const_) => converter.get_const(const_).map(Item::Const), + Item::Use(use_) => converter.get_use(use_).map(Item::Use), + Item::Type(type_) => converter.get_type(type_).map(Item::Type), + item => Some(item), + }) + .collect::<Vec<_>>(); + converter.flush(&mut module.items); + add_send_sync( + &mut module.items, + &[ + "hipDeviceptr_t", + "hipStream_t", + "hipModule_t", + "hipFunction_t", + ], + ); + let mut output = output.clone(); + output.extend(path); + write_rust_to_file(output, &prettyplease::unparse(&module)) +} + +fn add_send_sync(items: &mut Vec<Item>, arg: &[&str]) { + for type_ in arg { + let type_ = Ident::new(type_, Span::call_site()); + items.extend([ + parse_quote! { + unsafe impl Send for #type_ {} + }, + parse_quote! { + unsafe impl Sync for #type_ {} + }, + ]); + } +} + +fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) { + let fns_ = module.items.iter().filter_map(|item| match item { + Item::ForeignMod(extern_) => match &*extern_.items { + [ForeignItem::Fn(fn_)] => Some(fn_), + _ => unreachable!(), + }, + _ => None, + }); + let mut module: syn::File = parse_quote! { + extern "system" { + #(#fns_)* + } + }; + syn::visit_mut::visit_file_mut(&mut PrependCudaPath, &mut module); + syn::visit_mut::visit_file_mut(&mut RemoveVisibility, &mut module); + syn::visit_mut::visit_file_mut(&mut ExplicitReturnType, &mut module); + let mut output = output.clone(); + output.extend(path); + write_rust_to_file(output, &prettyplease::unparse(&module)) +} + +fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) { + let mut module = module.clone(); + let mut converter = ConvertIntoRustResult { + type_: "CUresult", + underlying_type: "cudaError_enum", + new_error_type: "CUerror", + error_prefix: ("CUDA_ERROR_", "ERROR_"), + success: ("CUDA_SUCCESS", "SUCCESS"), + constants: Vec::new(), + }; + module.items = module + .items + .into_iter() + .filter_map(|item| match item { + Item::ForeignMod(_) => None, + Item::Const(const_) => converter.get_const(const_).map(Item::Const), + Item::Use(use_) => converter.get_use(use_).map(Item::Use), + Item::Type(type_) => converter.get_type(type_).map(Item::Type), + Item::Struct(mut struct_) => { + let ident_string = struct_.ident.to_string(); + match &*ident_string { + "CUdeviceptr_v2" => { + struct_.fields = Fields::Unnamed(parse_quote! { + (pub *mut ::core::ffi::c_void) + }); + } + "CUuuid_st" => { + struct_.fields = Fields::Named(parse_quote! { + {pub bytes: [::core::ffi::c_uchar; 16usize]} + }); + } + _ => {} + } + Some(Item::Struct(struct_)) + } + item => Some(item), + }) + .collect::<Vec<_>>(); + converter.flush(&mut module.items); + module.items.push(parse_quote! { + impl From<hip_runtime_sys::hipErrorCode_t> for CUerror { + fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self { + Self(error.0) + } + } + }); + add_send_sync( + &mut module.items, + &[ + "CUdeviceptr", + "CUcontext", + "CUstream", + "CUmodule", + "CUfunction", + "CUlibrary", + ], + ); + syn::visit_mut::visit_file_mut(&mut FixAbi, &mut module); + let mut output = output.clone(); + output.extend(path); + write_rust_to_file(output, &prettyplease::unparse(&module)) +} + +fn write_rust_to_file(path: impl AsRef<std::path::Path>, content: &str) { + let mut file = File::create(path).unwrap(); + file.write("// Generated automatically by zluda_bindgen\n// DO NOT EDIT MANUALLY\n#![allow(warnings)]\n".as_bytes()) + .unwrap(); + file.write(content.as_bytes()).unwrap(); +} + +struct ConvertIntoRustResult { + type_: &'static str, + underlying_type: &'static str, + new_error_type: &'static str, + error_prefix: (&'static str, &'static str), + success: (&'static str, &'static str), + constants: Vec<syn::ItemConst>, +} + +impl ConvertIntoRustResult { + fn get_const(&mut self, const_: syn::ItemConst) -> Option<syn::ItemConst> { + let name = const_.ident.to_string(); + if name.starts_with(self.underlying_type) { + self.constants.push(const_); + None + } else { + Some(const_) + } + } + + fn get_use(&mut self, use_: ItemUse) -> Option<ItemUse> { + if let UseTree::Path(ref path) = use_.tree { + if let UseTree::Rename(ref rename) = &*path.tree { + if rename.rename == self.type_ { + return None; + } + } + } + Some(use_) + } + + fn flush(self, items: &mut Vec<Item>) { + let type_ = format_ident!("{}", self.type_); + let type_trait = format_ident!("{}Consts", self.type_); + let new_error_type = format_ident!("{}", self.new_error_type); + let success = format_ident!("{}", self.success.1); + let mut result_variants = Vec::new(); + let mut error_variants = Vec::new(); + for const_ in self.constants.iter() { + let ident = const_.ident.to_string(); + if ident.ends_with(self.success.0) { + result_variants.push(quote! { + const #success: #type_ = #type_::Ok(()); + }); + } else { + let old_prefix_len = self.underlying_type.len() + 1 + self.error_prefix.0.len(); + let variant_ident = + format_ident!("{}{}", self.error_prefix.1, &ident[old_prefix_len..]); + let error_ident = format_ident!("{}", &ident[old_prefix_len..]); + let expr = &const_.expr; + result_variants.push(quote! { + const #variant_ident: #type_ = #type_::Err(#new_error_type::#error_ident); + }); + error_variants.push(quote! { + pub const #error_ident: #new_error_type = #new_error_type(unsafe { ::core::num::NonZeroU32::new_unchecked(#expr) }); + }); + } + } + let extra_items: Punctuated<syn::Item, syn::parse::Nothing> = parse_quote! { + impl #new_error_type { + #(#error_variants)* + } + #[repr(transparent)] + #[derive(Debug, Hash, Copy, Clone, PartialEq, Eq)] + pub struct #new_error_type(pub ::core::num::NonZeroU32); + + pub trait #type_trait { + #(#result_variants)* + } + impl #type_trait for #type_ {} + #[must_use] + pub type #type_ = ::core::result::Result<(), #new_error_type>; + const _: fn() = || { + let _ = std::mem::transmute::<#type_, u32>; + }; + }; + items.extend(extra_items); + } + + fn get_type(&self, type_: syn::ItemType) -> Option<syn::ItemType> { + if type_.ident.to_string() == self.type_ { + None + } else { + Some(type_) + } + } +} + +struct FixAbi; + +impl VisitMut for FixAbi { + fn visit_abi_mut(&mut self, i: &mut Abi) { + if let Some(ref mut name) = i.name { + *name = LitStr::new("system", Span::call_site()); + } + } +} + +struct PrependCudaPath; + +impl VisitMut for PrependCudaPath { + fn visit_type_path_mut(&mut self, type_: &mut TypePath) { + if type_.path.segments.len() == 1 { + match &*type_.path.segments[0].ident.to_string() { + "usize" | "f64" | "f32" => {} + _ => { + *type_ = parse_quote! { cuda_types :: #type_ }; + } + } + } + } +} + +struct RemoveVisibility; + +impl VisitMut for RemoveVisibility { + fn visit_visibility_mut(&mut self, i: &mut syn::Visibility) { + *i = syn::Visibility::Inherited; + } +} + +struct ExplicitReturnType; + +impl VisitMut for ExplicitReturnType { + fn visit_return_type_mut(&mut self, i: &mut syn::ReturnType) { + if let syn::ReturnType::Default = i { + *i = parse_quote! { -> {} }; + } + } +} + +fn generate_display( + output: &PathBuf, + path: &[&str], + types_crate: &'static str, + module: &syn::File, +) { + let ignore_types = [ + "CUdevice", + "CUdeviceptr_v1", + "CUarrayMapInfo_st", + "CUDA_RESOURCE_DESC_st", + "CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st", + "CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st", + "CUexecAffinityParam_st", + "CUstreamBatchMemOpParams_union_CUstreamMemOpWaitValueParams_st", + "CUstreamBatchMemOpParams_union_CUstreamMemOpWriteValueParams_st", + "CUuuid_st", + "HGPUNV", + "EGLint", + "EGLSyncKHR", + "EGLImageKHR", + "EGLStreamKHR", + "CUasyncNotificationInfo_st", + "CUgraphNodeParams_st", + "CUeglFrame_st", + "CUdevResource_st", + "CUlaunchAttribute_st", + "CUlaunchConfig_st", + ]; + let ignore_functions = [ + "cuGLGetDevices", + "cuGLGetDevices_v2", + "cuStreamSetAttribute", + "cuStreamSetAttribute_ptsz", + "cuStreamGetAttribute", + "cuStreamGetAttribute_ptsz", + "cuGraphKernelNodeGetAttribute", + "cuGraphKernelNodeSetAttribute", + ]; + let count_selectors = [ + ("cuCtxCreate_v3", 1, 2), + ("cuMemMapArrayAsync", 0, 1), + ("cuMemMapArrayAsync_ptsz", 0, 1), + ("cuStreamBatchMemOp", 2, 1), + ("cuStreamBatchMemOp_ptsz", 2, 1), + ("cuStreamBatchMemOp_v2", 2, 1), + ]; + let mut derive_state = DeriveDisplayState::new( + &ignore_types, + types_crate, + &ignore_functions, + &count_selectors, + ); + let mut items = module + .items + .iter() + .filter_map(|i| cuda_derive_display_trait_for_item(&mut derive_state, i)) + .collect::<Vec<_>>(); + items.push(curesult_display_trait(&derive_state)); + let mut output = output.clone(); + output.extend(path); + write_rust_to_file( + output, + &prettyplease::unparse(&syn::File { + shebang: None, + attrs: Vec::new(), + items, + }), + ); +} + +struct DeriveDisplayState<'a> { + types_crate: &'static str, + ignore_types: FxHashSet<Ident>, + ignore_fns: FxHashSet<Ident>, + enums: FxHashMap<&'a Ident, Vec<&'a Ident>>, + array_arguments: FxHashMap<(Ident, usize), usize>, + result_variants: Vec<&'a ItemConst>, +} + +impl<'a> DeriveDisplayState<'a> { + fn new( + ignore_types: &[&'static str], + types_crate: &'static str, + ignore_fns: &[&'static str], + count_selectors: &[(&'static str, usize, usize)], + ) -> Self { + DeriveDisplayState { + types_crate, + ignore_types: ignore_types + .into_iter() + .map(|x| Ident::new(x, Span::call_site())) + .collect(), + ignore_fns: ignore_fns + .into_iter() + .map(|x| Ident::new(x, Span::call_site())) + .collect(), + array_arguments: count_selectors + .into_iter() + .map(|(name, val, count)| ((Ident::new(name, Span::call_site()), *val), *count)) + .collect(), + enums: Default::default(), + result_variants: Vec::new(), + } + } + + fn record_enum_variant(&mut self, enum_: &'a Ident, variant: &'a Ident) { + match self.enums.entry(enum_) { + hash_map::Entry::Occupied(mut entry) => { + entry.get_mut().push(variant); + } + hash_map::Entry::Vacant(entry) => { + entry.insert(vec![variant]); + } + } + } +} + +fn cuda_derive_display_trait_for_item<'a>( + state: &mut DeriveDisplayState<'a>, + item: &'a Item, +) -> Option<syn::Item> { + let path_prefix = Path::from(Ident::new(state.types_crate, Span::call_site())); + let path_prefix_iter = iter::repeat(&path_prefix); + match item { + Item::Const(const_) => { + if const_.ty.to_token_stream().to_string() == "cudaError_enum" { + state.result_variants.push(const_); + } + None + } + Item::ForeignMod(ItemForeignMod { items, .. }) => match items.last().unwrap() { + ForeignItem::Fn(ForeignItemFn { + sig: Signature { ident, inputs, .. }, + .. + }) => { + if state.ignore_fns.contains(ident) { + return None; + } + let inputs = inputs + .iter() + .map(|fn_arg| { + let mut fn_arg = fn_arg.clone(); + syn::visit_mut::visit_fn_arg_mut(&mut PrependCudaPath, &mut fn_arg); + fn_arg + }) + .collect::<Vec<_>>(); + let inputs_iter = inputs.iter(); + let original_fn_name = ident.to_string(); + let mut write_argument = inputs.iter().enumerate().map(|(index, fn_arg)| { + let name = fn_arg_name(fn_arg); + if let Some(length_index) = state.array_arguments.get(&(ident.clone(), index)) { + let length = fn_arg_name(&inputs[*length_index]); + quote! { + writer.write_all(concat!(stringify!(#name), ": ").as_bytes())?; + writer.write_all(b"[")?; + for i in 0..#length { + if i != 0 { + writer.write_all(b", ")?; + } + crate::format::CudaDisplay::write(unsafe { &*#name.add(i as usize) }, #original_fn_name, arg_idx, writer)?; + } + writer.write_all(b"]")?; + } + } else { + quote! { + writer.write_all(concat!(stringify!(#name), ": ").as_bytes())?; + crate::format::CudaDisplay::write(&#name, #original_fn_name, arg_idx, writer)?; + } + } + }); + let fn_name = format_ident!("write_{}", ident); + Some(match write_argument.next() { + Some(first_write_argument) => parse_quote! { + pub fn #fn_name(writer: &mut (impl std::io::Write + ?Sized), #(#inputs_iter,)*) -> std::io::Result<()> { + let mut arg_idx = 0usize; + writer.write_all(b"(")?; + #first_write_argument + #( + arg_idx += 1; + writer.write_all(b", ")?; + #write_argument + )* + writer.write_all(b")") + } + }, + None => parse_quote! { + pub fn #fn_name(writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + writer.write_all(b"()") + } + }, + }) + } + _ => unreachable!(), + }, + Item::Impl(ref item_impl) => { + let enum_ = match &*item_impl.self_ty { + Type::Path(ref path) => &path.path.segments.last().unwrap().ident, + _ => unreachable!(), + }; + let variant_ = match item_impl.items.last().unwrap() { + syn::ImplItem::Const(item_const) => &item_const.ident, + _ => unreachable!(), + }; + state.record_enum_variant(enum_, variant_); + None + } + Item::Struct(item_struct) => { + if state.ignore_types.contains(&item_struct.ident) { + return None; + } + if state.enums.contains_key(&item_struct.ident) { + let enum_ = &item_struct.ident; + let enum_iter = iter::repeat(&item_struct.ident); + let variants = state.enums.get(&item_struct.ident).unwrap().iter(); + Some(parse_quote! { + impl crate::format::CudaDisplay for #path_prefix :: #enum_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + match self { + #(& #path_prefix_iter :: #enum_iter :: #variants => writer.write_all(stringify!(#variants).as_bytes()),)* + _ => write!(writer, "{}", self.0) + } + } + } + }) + } else { + let struct_ = &item_struct.ident; + match item_struct.fields { + Fields::Named(ref fields) => { + let mut rest_of_fields = fields.named.iter().filter_map(|f| { + let f_ident = f.ident.as_ref().unwrap(); + let name = f_ident.to_string(); + if name.starts_with("reserved") || name == "_unused" { + None + } else { + Some(f_ident) + } + }); + let first_field = match rest_of_fields.next() { + Some(f) => f, + None => return None, + }; + Some(parse_quote! { + impl crate::format::CudaDisplay for #path_prefix :: #struct_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + writer.write_all(concat!("{ ", stringify!(#first_field), ": ").as_bytes())?; + crate::format::CudaDisplay::write(&self.#first_field, "", 0, writer)?; + #( + writer.write_all(concat!(", ", stringify!(#rest_of_fields), ": ").as_bytes())?; + crate::format::CudaDisplay::write(&self.#rest_of_fields, "", 0, writer)?; + )* + writer.write_all(b" }") + } + } + }) + } + Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }) if unnamed.len() == 1 => { + Some(parse_quote! { + impl crate::format::CudaDisplay for #path_prefix :: #struct_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + write!(writer, "{:p}", self.0) + } + } + }) + } + _ => return None, + } + } + } + Item::Type(item_type) => { + if state.ignore_types.contains(&item_type.ident) { + return None; + }; + match &*item_type.ty { + Type::Ptr(_) => { + let type_ = &item_type.ident; + Some(parse_quote! { + impl crate::format::CudaDisplay for #path_prefix :: #type_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + write!(writer, "{:p}", *self) + } + } + }) + } + Type::Path(type_path) => { + if type_path.path.leading_colon.is_some() { + let option_seg = type_path.path.segments.last().unwrap(); + if option_seg.ident == "Option" { + match &option_seg.arguments { + PathArguments::AngleBracketed(generic) => match generic.args[0] { + syn::GenericArgument::Type(Type::BareFn(_)) => { + let type_ = &item_type.ident; + return Some(parse_quote! { + impl crate::format::CudaDisplay for #path_prefix :: #type_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + write!(writer, "{:p}", unsafe { std::mem::transmute::<#path_prefix :: #type_, *mut ::std::ffi::c_void>(*self) }) + } + } + }); + } + _ => unreachable!(), + }, + _ => unreachable!(), + } + } + } + None + } + _ => unreachable!(), + } + } + Item::Union(_) => None, + Item::Use(_) => None, + _ => unreachable!(), + } +} + +fn fn_arg_name(fn_arg: &FnArg) -> &Box<syn::Pat> { + let name = if let FnArg::Typed(t) = fn_arg { + &t.pat + } else { + unreachable!() + }; + name +} + +fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item { + let errors = derive_state.result_variants.iter().filter_map(|const_| { + let prefix = "cudaError_enum_"; + let text = &const_.ident.to_string()[prefix.len()..]; + if text == "CUDA_SUCCESS" { + return None; + } + let expr = &const_.expr; + Some(quote! { + #expr => writer.write_all(#text.as_bytes()), + }) + }); + parse_quote! { + impl crate::format::CudaDisplay for cuda_types::CUresult { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + match self { + Ok(()) => writer.write_all(b"CUDA_SUCCESS"), + Err(err) => { + match err.0.get() { + #(#errors)* + err => write!(writer, "{}", err) + } + } + } + } + } + } +} |