diff options
Diffstat (limited to 'zluda_bindgen/src/main.rs')
-rw-r--r-- | zluda_bindgen/src/main.rs | 129 |
1 files changed, 113 insertions, 16 deletions
diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs index 7332254..bfa9d49 100644 --- a/zluda_bindgen/src/main.rs +++ b/zluda_bindgen/src/main.rs @@ -5,7 +5,7 @@ use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str:: use syn::{ parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FieldsUnnamed, FnArg, ForeignItem, ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path, - PathArguments, Signature, Type, TypePath, UseTree, + PathArguments, Signature, Type, TypePath, UseTree, PathSegment }; fn main() { @@ -14,6 +14,11 @@ fn main() { &crate_root, &["..", "ext", "hip_runtime-sys", "src", "lib.rs"], ); + generate_ml(&crate_root); + generate_cuda(&crate_root); +} + +fn generate_cuda(crate_root: &PathBuf) { let cuda_header = bindgen::Builder::default() .use_core() .rust_target(bindgen::RustTarget::Stable_1_77) @@ -42,16 +47,91 @@ fn main() { .unwrap() .to_string(); let module: syn::File = syn::parse_str(&cuda_header).unwrap(); - generate_functions(&crate_root, &["..", "cuda_base", "src", "cuda.rs"], &module); - generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module); + generate_functions( + &crate_root, + "cuda", + &["..", "cuda_base", "src", "cuda.rs"], + &module, + ); + generate_types_cuda( + &crate_root, + &["..", "cuda_types", "src", "cuda.rs"], + &module, + ); generate_display( &crate_root, &["..", "zluda_dump", "src", "format_generated.rs"], - "cuda_types", + &["cuda_types", "cuda"], &module, ) } +fn generate_ml(crate_root: &PathBuf) { + let ml_header = bindgen::Builder::default() + .use_core() + .rust_target(bindgen::RustTarget::Stable_1_77) + .layout_tests(false) + .default_enum_style(bindgen::EnumVariation::NewType { + is_bitfield: false, + is_global: false, + }) + .derive_hash(true) + .derive_eq(true) + .header("/usr/local/cuda/include/nvml.h") + .allowlist_type("^nvml.*") + .allowlist_function("^nvml.*") + .allowlist_var("^NVML.*") + .must_use_type("nvmlReturn_t") + .constified_enum("nvmlReturn_enum") + .generate() + .unwrap() + .to_string(); + let mut module: syn::File = syn::parse_str(&ml_header).unwrap(); + let mut converter = ConvertIntoRustResult { + type_: "nvmlReturn_t", + underlying_type: "nvmlReturn_enum", + new_error_type: "nvmlError_t", + error_prefix: ("NVML_ERROR_", "ERROR_"), + success: ("NVML_SUCCESS", "SUCCESS"), + constants: Vec::new(), + }; + module.items = module + .items + .into_iter() + .filter_map(|item| match item { + Item::Const(const_) => converter.get_const(const_).map(Item::Const), + Item::Use(use_) => converter.get_use(use_).map(Item::Use), + Item::Type(type_) => converter.get_type(type_).map(Item::Type), + item => Some(item), + }) + .collect::<Vec<_>>(); + converter.flush(&mut module.items); + generate_functions( + &crate_root, + "nvml", + &["..", "cuda_base", "src", "nvml.rs"], + &module, + ); + generate_types( + &crate_root, + &["..", "cuda_types", "src", "nvml.rs"], + &module, + ); +} + +fn generate_types(crate_root: &PathBuf, path: &[&str], module: &syn::File) { + let non_fn = module.items.iter().filter_map(|item| match item { + Item::ForeignMod(_) => None, + _ => Some(item), + }); + let module: syn::File = parse_quote! { + #(#non_fn)* + }; + let mut output = crate_root.clone(); + output.extend(path); + write_rust_to_file(output, &prettyplease::unparse(&module)) +} + fn generate_hip_runtime(output: &PathBuf, path: &[&str]) { let hiprt_header = bindgen::Builder::default() .use_core() @@ -125,7 +205,7 @@ fn add_send_sync(items: &mut Vec<Item>, arg: &[&str]) { } } -fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) { +fn generate_functions(output: &PathBuf, submodule: &str, path: &[&str], module: &syn::File) { let fns_ = module.items.iter().filter_map(|item| match item { Item::ForeignMod(extern_) => match &*extern_.items { [ForeignItem::Fn(fn_)] => Some(fn_), @@ -138,7 +218,8 @@ fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) { #(#fns_)* } }; - syn::visit_mut::visit_file_mut(&mut PrependCudaPath, &mut module); + let submodule = Ident::new(submodule, Span::call_site()); + syn::visit_mut::visit_file_mut(&mut PrependCudaPath { module: submodule }, &mut module); syn::visit_mut::visit_file_mut(&mut RemoveVisibility, &mut module); syn::visit_mut::visit_file_mut(&mut ExplicitReturnType, &mut module); let mut output = output.clone(); @@ -146,7 +227,7 @@ fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) { write_rust_to_file(output, &prettyplease::unparse(&module)) } -fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) { +fn generate_types_cuda(output: &PathBuf, path: &[&str], module: &syn::File) { let mut module = module.clone(); let mut converter = ConvertIntoRustResult { type_: "CUresult", @@ -314,7 +395,9 @@ impl VisitMut for FixAbi { } } -struct PrependCudaPath; +struct PrependCudaPath { + module: Ident, +} impl VisitMut for PrependCudaPath { fn visit_type_path_mut(&mut self, type_: &mut TypePath) { @@ -322,7 +405,8 @@ impl VisitMut for PrependCudaPath { match &*type_.path.segments[0].ident.to_string() { "usize" | "f64" | "f32" => {} _ => { - *type_ = parse_quote! { cuda_types :: #type_ }; + let module = &self.module; + *type_ = parse_quote! { cuda_types :: #module :: #type_ }; } } } @@ -350,7 +434,7 @@ impl VisitMut for ExplicitReturnType { fn generate_display( output: &PathBuf, path: &[&str], - types_crate: &'static str, + types_crate: &[&'static str], module: &syn::File, ) { let ignore_types = [ @@ -419,7 +503,7 @@ fn generate_display( } struct DeriveDisplayState<'a> { - types_crate: &'static str, + types_crate: Path, ignore_types: FxHashSet<Ident>, ignore_fns: FxHashSet<Ident>, enums: FxHashMap<&'a Ident, Vec<&'a Ident>>, @@ -430,12 +514,22 @@ struct DeriveDisplayState<'a> { impl<'a> DeriveDisplayState<'a> { fn new( ignore_types: &[&'static str], - types_crate: &'static str, + types_crate: &[&'static str], ignore_fns: &[&'static str], count_selectors: &[(&'static str, usize, usize)], ) -> Self { + let segments = types_crate + .iter() + .map(|seg| PathSegment { + ident: Ident::new(seg, Span::call_site()), + arguments: PathArguments::None, + }) + .collect::<Punctuated<_, _>>(); DeriveDisplayState { - types_crate, + types_crate: Path { + leading_colon: None, + segments, + }, ignore_types: ignore_types .into_iter() .map(|x| Ident::new(x, Span::call_site())) @@ -469,8 +563,11 @@ fn cuda_derive_display_trait_for_item<'a>( state: &mut DeriveDisplayState<'a>, item: &'a Item, ) -> Option<syn::Item> { - let path_prefix = Path::from(Ident::new(state.types_crate, Span::call_site())); + let path_prefix = & state.types_crate; let path_prefix_iter = iter::repeat(&path_prefix); + let mut prepend_path = PrependCudaPath { + module: Ident::new("cuda", Span::call_site()), + }; match item { Item::Const(const_) => { if const_.ty.to_token_stream().to_string() == "cudaError_enum" { @@ -490,7 +587,7 @@ fn cuda_derive_display_trait_for_item<'a>( .iter() .map(|fn_arg| { let mut fn_arg = fn_arg.clone(); - syn::visit_mut::visit_fn_arg_mut(&mut PrependCudaPath, &mut fn_arg); + syn::visit_mut::visit_fn_arg_mut(&mut prepend_path, &mut fn_arg); fn_arg }) .collect::<Vec<_>>(); @@ -686,7 +783,7 @@ fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item { }) }); parse_quote! { - impl crate::format::CudaDisplay for cuda_types::CUresult { + impl crate::format::CudaDisplay for cuda_types::cuda::CUresult { fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { match self { Ok(()) => writer.write_all(b"CUDA_SUCCESS"), |