diff options
Diffstat (limited to 'zluda_bindgen/src/main.rs')
-rw-r--r-- | zluda_bindgen/src/main.rs | 107 |
1 files changed, 64 insertions, 43 deletions
diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs index 5e3de53..ebb357c 100644 --- a/zluda_bindgen/src/main.rs +++ b/zluda_bindgen/src/main.rs @@ -5,7 +5,7 @@ use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str:: use syn::{ parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FnArg, ForeignItem, ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path, PathArguments, - Signature, Type, UseTree, + Signature, Type, TypePath, UseTree, }; fn main() { @@ -32,6 +32,11 @@ fn main() { .unwrap() .to_string(); let module: syn::File = syn::parse_str(&cuda_header).unwrap(); + generate_functions( + &crate_root, + &["..", "cuda_base", "src", "cuda.rs"], + &module, + ); generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module); generate_display( &crate_root, @@ -41,6 +46,27 @@ fn main() { ) } +fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) { + let fns_ = module.items.iter().filter_map(|item| match item { + Item::ForeignMod(extern_) => match &*extern_.items { + [ForeignItem::Fn(fn_)] => Some(fn_), + _ => unreachable!(), + }, + _ => None, + }); + let mut module: syn::File = parse_quote! { + extern "system" { + #(#fns_)* + } + }; + syn::visit_mut::visit_file_mut(&mut PrependCudaPath, &mut module); + syn::visit_mut::visit_file_mut(&mut RemoveVisibility, &mut module); + syn::visit_mut::visit_file_mut(&mut ExplicitReturnType, &mut module); + let mut output = output.clone(); + output.extend(path); + write_rust_to_file(output, &prettyplease::unparse(&module)) +} + fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) { let mut module = module.clone(); let mut converter = ConvertIntoRustResult { @@ -181,6 +207,39 @@ impl VisitMut for FixAbi { } } +struct PrependCudaPath; + +impl VisitMut for PrependCudaPath { + fn visit_type_path_mut(&mut self, type_: &mut TypePath) { + if type_.path.segments.len() == 1 { + match &*type_.path.segments[0].ident.to_string() { + "usize" | "f64" | "f32" => {} + _ => { + *type_ = parse_quote! { cuda_types :: #type_ }; + } + } + } + } +} + +struct RemoveVisibility; + +impl VisitMut for RemoveVisibility { + fn visit_visibility_mut(&mut self, i: &mut syn::Visibility) { + *i = syn::Visibility::Inherited; + } +} + +struct ExplicitReturnType; + +impl VisitMut for ExplicitReturnType { + fn visit_return_type_mut(&mut self, i: &mut syn::ReturnType) { + if let syn::ReturnType::Default = i { + *i = parse_quote! { -> {} }; + } + } +} + fn generate_display( output: &PathBuf, path: &[&str], @@ -320,13 +379,10 @@ fn cuda_derive_display_trait_for_item<'a>( } let inputs = inputs .iter() - .map(|fn_arg| match fn_arg { - FnArg::Typed(ref pat_type) => { - let mut pat_type = pat_type.clone(); - pat_type.ty = prepend_cuda_path_to_type(&path_prefix, pat_type.ty); - FnArg::Typed(pat_type) - } - _ => unreachable!(), + .map(|fn_arg| { + let mut fn_arg = fn_arg.clone(); + syn::visit_mut::visit_fn_arg_mut(&mut PrependCudaPath, &mut fn_arg); + fn_arg }) .collect::<Vec<_>>(); let inputs_iter = inputs.iter(); @@ -500,41 +556,6 @@ fn fn_arg_name(fn_arg: &FnArg) -> &Box<syn::Pat> { name } -fn prepend_cuda_path_to_type(base_path: &Path, type_: Box<Type>) -> Box<Type> { - match *type_ { - Type::Path(mut type_path) => { - type_path.path = prepend_cuda_path_to_path(base_path, type_path.path); - Box::new(Type::Path(type_path)) - } - Type::Ptr(mut type_ptr) => { - type_ptr.elem = prepend_cuda_path_to_type(base_path, type_ptr.elem); - Box::new(Type::Ptr(type_ptr)) - } - _ => unreachable!(), - } -} - -fn prepend_cuda_path_to_path(base_path: &Path, path: Path) -> Path { - if path.leading_colon.is_some() { - return path; - } - if path.segments.len() == 1 { - let ident = path.segments[0].ident.to_string(); - if ident.starts_with("CU") - || ident.starts_with("cu") - || ident.starts_with("GL") - || ident.starts_with("EGL") - || ident.starts_with("Vdp") - || ident == "HGPUNV" - { - let mut base_path = base_path.clone(); - base_path.segments.extend(path.segments); - return base_path; - } - } - path -} - fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item { let errors = derive_state.result_variants.iter().filter_map(|const_| { let prefix = "cudaError_enum_"; |