diff options
Diffstat (limited to 'zluda_bindgen')
-rw-r--r-- | zluda_bindgen/Cargo.toml | 1 | ||||
-rw-r--r-- | zluda_bindgen/src/main.rs | 443 |
2 files changed, 427 insertions, 17 deletions
diff --git a/zluda_bindgen/Cargo.toml b/zluda_bindgen/Cargo.toml index df53d49..791ad2c 100644 --- a/zluda_bindgen/Cargo.toml +++ b/zluda_bindgen/Cargo.toml @@ -9,3 +9,4 @@ syn = { version = "2.0", features = ["full", "visit-mut"] } proc-macro2 = "1.0.89" quote = "1.0" prettyplease = "0.2.25" +rustc-hash = "1.1.0" diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs index e90e07b..5e3de53 100644 --- a/zluda_bindgen/src/main.rs +++ b/zluda_bindgen/src/main.rs @@ -1,8 +1,11 @@ use proc_macro2::Span; -use quote::{format_ident, quote}; -use std::{path::PathBuf, str::FromStr}; +use quote::{format_ident, quote, ToTokens}; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::FromStr}; use syn::{ - parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Item, ItemUse, LitStr, UseTree, + parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FnArg, ForeignItem, + ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path, PathArguments, + Signature, Type, UseTree, }; fn main() { @@ -28,18 +31,18 @@ fn main() { .generate() .unwrap() .to_string(); - generate_types( - crate_root, - &["..", "cuda_types", "src", "lib.rs"], - cuda_header, - ); + let module: syn::File = syn::parse_str(&cuda_header).unwrap(); + generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module); + generate_display( + &crate_root, + &["..", "zluda_dump", "src", "format_generated.rs"], + "cuda_types", + &module, + ) } -fn generate_types(mut output: PathBuf, path: &[&str], cuda_header: String) { - let mut module: syn::File = syn::parse_str(&cuda_header).unwrap(); - module.attrs.push(parse_quote! { - #![allow(warnings)] - }); +fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) { + let mut module = module.clone(); let mut converter = ConvertIntoRustResult { type_: "CUresult", underlying_type: "cudaError_enum", @@ -55,15 +58,38 @@ fn generate_types(mut output: PathBuf, path: &[&str], cuda_header: String) { Item::ForeignMod(_) => None, Item::Const(const_) => converter.get_const(const_).map(Item::Const), Item::Use(use_) => converter.get_use(use_).map(Item::Use), + Item::Struct(mut struct_) => { + let ident_string = struct_.ident.to_string(); + match &*ident_string { + "CUdeviceptr_v2" => { + struct_.fields = Fields::Unnamed(parse_quote! { + (pub *mut ::core::ffi::c_void) + }); + } + "CUuuid_st" => { + struct_.fields = Fields::Named(parse_quote! { + {pub bytes: [::core::ffi::c_uchar; 16usize]} + }); + } + _ => {} + } + Some(Item::Struct(struct_)) + } item => Some(item), }) .collect::<Vec<_>>(); converter.flush(&mut module.items); syn::visit_mut::visit_file_mut(&mut FixAbi, &mut module); - for segment in path { - output.push(segment); - } - std::fs::write(output, prettyplease::unparse(&module)).unwrap(); + let mut output = output.clone(); + output.extend(path); + write_rust_to_file(output, &prettyplease::unparse(&module)) +} + +fn write_rust_to_file(path: impl AsRef<std::path::Path>, content: &str) { + let mut file = File::create(path).unwrap(); + file.write("// Generated automatically by zluda_bindgen\n// DO NOT EDIT MANUALLY\n#![allow(warnings)]\n".as_bytes()) + .unwrap(); + file.write(content.as_bytes()).unwrap(); } struct ConvertIntoRustResult { @@ -154,3 +180,386 @@ impl VisitMut for FixAbi { } } } + +fn generate_display( + output: &PathBuf, + path: &[&str], + types_crate: &'static str, + module: &syn::File, +) { + let ignore_types = [ + "CUarrayMapInfo_st", + "CUDA_RESOURCE_DESC_st", + "CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st", + "CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st", + "CUexecAffinityParam_st", + "CUstreamBatchMemOpParams_union_CUstreamMemOpWaitValueParams_st", + "CUstreamBatchMemOpParams_union_CUstreamMemOpWriteValueParams_st", + "CUuuid_st", + "HGPUNV", + "EGLint", + "EGLSyncKHR", + "EGLImageKHR", + "EGLStreamKHR", + "CUasyncNotificationInfo_st", + "CUgraphNodeParams_st", + "CUeglFrame_st", + "CUdevResource_st", + "CUlaunchAttribute_st", + "CUlaunchConfig_st", + ]; + let ignore_functions = [ + "cuGLGetDevices", + "cuGLGetDevices_v2", + "cuStreamSetAttribute", + "cuStreamSetAttribute_ptsz", + "cuStreamGetAttribute", + "cuStreamGetAttribute_ptsz", + "cuGraphKernelNodeGetAttribute", + "cuGraphKernelNodeSetAttribute", + ]; + let count_selectors = [ + ("cuCtxCreate_v3", 1, 2), + ("cuMemMapArrayAsync", 0, 1), + ("cuMemMapArrayAsync_ptsz", 0, 1), + ("cuStreamBatchMemOp", 2, 1), + ("cuStreamBatchMemOp_ptsz", 2, 1), + ("cuStreamBatchMemOp_v2", 2, 1), + ]; + let mut derive_state = DeriveDisplayState::new( + &ignore_types, + types_crate, + &ignore_functions, + &count_selectors, + ); + let mut items = module + .items + .iter() + .filter_map(|i| cuda_derive_display_trait_for_item(&mut derive_state, i)) + .collect::<Vec<_>>(); + items.push(curesult_display_trait(&derive_state)); + let mut output = output.clone(); + output.extend(path); + write_rust_to_file( + output, + &prettyplease::unparse(&syn::File { + shebang: None, + attrs: Vec::new(), + items, + }), + ); +} + +struct DeriveDisplayState<'a> { + types_crate: &'static str, + ignore_types: FxHashSet<Ident>, + ignore_fns: FxHashSet<Ident>, + enums: FxHashMap<&'a Ident, Vec<&'a Ident>>, + array_arguments: FxHashMap<(Ident, usize), usize>, + result_variants: Vec<&'a ItemConst>, +} + +impl<'a> DeriveDisplayState<'a> { + fn new( + ignore_types: &[&'static str], + types_crate: &'static str, + ignore_fns: &[&'static str], + count_selectors: &[(&'static str, usize, usize)], + ) -> Self { + DeriveDisplayState { + types_crate, + ignore_types: ignore_types + .into_iter() + .map(|x| Ident::new(x, Span::call_site())) + .collect(), + ignore_fns: ignore_fns + .into_iter() + .map(|x| Ident::new(x, Span::call_site())) + .collect(), + array_arguments: count_selectors + .into_iter() + .map(|(name, val, count)| ((Ident::new(name, Span::call_site()), *val), *count)) + .collect(), + enums: Default::default(), + result_variants: Vec::new(), + } + } + + fn record_enum_variant(&mut self, enum_: &'a Ident, variant: &'a Ident) { + match self.enums.entry(enum_) { + hash_map::Entry::Occupied(mut entry) => { + entry.get_mut().push(variant); + } + hash_map::Entry::Vacant(entry) => { + entry.insert(vec![variant]); + } + } + } +} + +fn cuda_derive_display_trait_for_item<'a>( + state: &mut DeriveDisplayState<'a>, + item: &'a Item, +) -> Option<syn::Item> { + let path_prefix = Path::from(Ident::new(state.types_crate, Span::call_site())); + let path_prefix_iter = iter::repeat(&path_prefix); + match item { + Item::Const(const_) => { + if const_.ty.to_token_stream().to_string() == "cudaError_enum" { + state.result_variants.push(const_); + } + None + } + Item::ForeignMod(ItemForeignMod { items, .. }) => match items.last().unwrap() { + ForeignItem::Fn(ForeignItemFn { + sig: Signature { ident, inputs, .. }, + .. + }) => { + if state.ignore_fns.contains(ident) { + return None; + } + let inputs = inputs + .iter() + .map(|fn_arg| match fn_arg { + FnArg::Typed(ref pat_type) => { + let mut pat_type = pat_type.clone(); + pat_type.ty = prepend_cuda_path_to_type(&path_prefix, pat_type.ty); + FnArg::Typed(pat_type) + } + _ => unreachable!(), + }) + .collect::<Vec<_>>(); + let inputs_iter = inputs.iter(); + let original_fn_name = ident.to_string(); + let mut write_argument = inputs.iter().enumerate().map(|(index, fn_arg)| { + let name = fn_arg_name(fn_arg); + if let Some(length_index) = state.array_arguments.get(&(ident.clone(), index)) { + let length = fn_arg_name(&inputs[*length_index]); + quote! { + writer.write_all(concat!(stringify!(#name), ": ").as_bytes())?; + writer.write_all(b"[")?; + for i in 0..#length { + if i != 0 { + writer.write_all(b", ")?; + } + crate::format::CudaDisplay::write(unsafe { &*#name.add(i as usize) }, #original_fn_name, arg_idx, writer)?; + } + writer.write_all(b"]")?; + } + } else { + quote! { + writer.write_all(concat!(stringify!(#name), ": ").as_bytes())?; + crate::format::CudaDisplay::write(&#name, #original_fn_name, arg_idx, writer)?; + } + } + }); + let fn_name = format_ident!("write_{}", ident); + Some(match write_argument.next() { + Some(first_write_argument) => parse_quote! { + pub fn #fn_name(writer: &mut (impl std::io::Write + ?Sized), #(#inputs_iter,)*) -> std::io::Result<()> { + let mut arg_idx = 0usize; + writer.write_all(b"(")?; + #first_write_argument + #( + arg_idx += 1; + writer.write_all(b", ")?; + #write_argument + )* + writer.write_all(b")") + } + }, + None => parse_quote! { + pub fn #fn_name(writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + writer.write_all(b"()") + } + }, + }) + } + _ => unreachable!(), + }, + Item::Impl(ref item_impl) => { + let enum_ = match &*item_impl.self_ty { + Type::Path(ref path) => &path.path.segments.last().unwrap().ident, + _ => unreachable!(), + }; + let variant_ = match item_impl.items.last().unwrap() { + syn::ImplItem::Const(item_const) => &item_const.ident, + _ => unreachable!(), + }; + state.record_enum_variant(enum_, variant_); + None + } + Item::Struct(item_struct) => { + if state.ignore_types.contains(&item_struct.ident) { + return None; + } + if state.enums.contains_key(&item_struct.ident) { + let enum_ = &item_struct.ident; + let enum_iter = iter::repeat(&item_struct.ident); + let variants = state.enums.get(&item_struct.ident).unwrap().iter(); + Some(parse_quote! { + impl crate::format::CudaDisplay for #path_prefix :: #enum_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + match self { + #(& #path_prefix_iter :: #enum_iter :: #variants => writer.write_all(stringify!(#variants).as_bytes()),)* + _ => write!(writer, "{}", self.0) + } + } + } + }) + } else { + let struct_ = &item_struct.ident; + let (first_field, rest_of_fields) = match item_struct.fields { + Fields::Named(ref fields) => { + let mut all_idents = fields.named.iter().filter_map(|f| { + let f_ident = f.ident.as_ref().unwrap(); + let name = f_ident.to_string(); + if name.starts_with("reserved") || name == "_unused" { + None + } else { + Some(f_ident) + } + }); + let first = match all_idents.next() { + Some(f) => f, + None => return None, + }; + (first, all_idents) + } + _ => return None, + }; + Some(parse_quote! { + impl crate::format::CudaDisplay for #path_prefix :: #struct_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + writer.write_all(concat!("{ ", stringify!(#first_field), ": ").as_bytes())?; + crate::format::CudaDisplay::write(&self.#first_field, "", 0, writer)?; + #( + writer.write_all(concat!(", ", stringify!(#rest_of_fields), ": ").as_bytes())?; + crate::format::CudaDisplay::write(&self.#rest_of_fields, "", 0, writer)?; + )* + writer.write_all(b" }") + } + } + }) + } + } + Item::Type(item_type) => { + if state.ignore_types.contains(&item_type.ident) { + return None; + }; + match &*item_type.ty { + Type::Ptr(_) => { + let type_ = &item_type.ident; + Some(parse_quote! { + impl crate::format::CudaDisplay for #path_prefix :: #type_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + write!(writer, "{:p}", *self) + } + } + }) + } + Type::Path(type_path) => { + if type_path.path.leading_colon.is_some() { + let option_seg = type_path.path.segments.last().unwrap(); + if option_seg.ident == "Option" { + match &option_seg.arguments { + PathArguments::AngleBracketed(generic) => match generic.args[0] { + syn::GenericArgument::Type(Type::BareFn(_)) => { + let type_ = &item_type.ident; + return Some(parse_quote! { + impl crate::format::CudaDisplay for #path_prefix :: #type_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + write!(writer, "{:p}", unsafe { std::mem::transmute::<#path_prefix :: #type_, *mut ::std::ffi::c_void>(*self) }) + } + } + }); + } + _ => unreachable!(), + }, + _ => unreachable!(), + } + } + } + None + } + _ => unreachable!(), + } + } + Item::Union(_) => None, + Item::Use(_) => None, + _ => unreachable!(), + } +} + +fn fn_arg_name(fn_arg: &FnArg) -> &Box<syn::Pat> { + let name = if let FnArg::Typed(t) = fn_arg { + &t.pat + } else { + unreachable!() + }; + name +} + +fn prepend_cuda_path_to_type(base_path: &Path, type_: Box<Type>) -> Box<Type> { + match *type_ { + Type::Path(mut type_path) => { + type_path.path = prepend_cuda_path_to_path(base_path, type_path.path); + Box::new(Type::Path(type_path)) + } + Type::Ptr(mut type_ptr) => { + type_ptr.elem = prepend_cuda_path_to_type(base_path, type_ptr.elem); + Box::new(Type::Ptr(type_ptr)) + } + _ => unreachable!(), + } +} + +fn prepend_cuda_path_to_path(base_path: &Path, path: Path) -> Path { + if path.leading_colon.is_some() { + return path; + } + if path.segments.len() == 1 { + let ident = path.segments[0].ident.to_string(); + if ident.starts_with("CU") + || ident.starts_with("cu") + || ident.starts_with("GL") + || ident.starts_with("EGL") + || ident.starts_with("Vdp") + || ident == "HGPUNV" + { + let mut base_path = base_path.clone(); + base_path.segments.extend(path.segments); + return base_path; + } + } + path +} + +fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item { + let errors = derive_state.result_variants.iter().filter_map(|const_| { + let prefix = "cudaError_enum_"; + let text = &const_.ident.to_string()[prefix.len()..]; + if text == "CUDA_SUCCESS" { + return None; + } + let expr = &const_.expr; + Some(quote! { + #expr => writer.write_all(#text.as_bytes()), + }) + }); + parse_quote! { + impl crate::format::CudaDisplay for cuda_types::CUresult { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + match self { + Ok(()) => writer.write_all(b"CUDA_SUCCESS"), + Err(err) => { + match err.0.get() { + #(#errors)* + err => write!(writer, "{}", err) + } + } + } + } + } + } +} |