aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_bindgen/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda_bindgen/src/main.rs')
-rw-r--r--zluda_bindgen/src/main.rs107
1 files changed, 64 insertions, 43 deletions
diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs
index 5e3de53..ebb357c 100644
--- a/zluda_bindgen/src/main.rs
+++ b/zluda_bindgen/src/main.rs
@@ -5,7 +5,7 @@ use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::
use syn::{
parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FnArg, ForeignItem,
ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path, PathArguments,
- Signature, Type, UseTree,
+ Signature, Type, TypePath, UseTree,
};
fn main() {
@@ -32,6 +32,11 @@ fn main() {
.unwrap()
.to_string();
let module: syn::File = syn::parse_str(&cuda_header).unwrap();
+ generate_functions(
+ &crate_root,
+ &["..", "cuda_base", "src", "cuda.rs"],
+ &module,
+ );
generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module);
generate_display(
&crate_root,
@@ -41,6 +46,27 @@ fn main() {
)
}
+fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) {
+ let fns_ = module.items.iter().filter_map(|item| match item {
+ Item::ForeignMod(extern_) => match &*extern_.items {
+ [ForeignItem::Fn(fn_)] => Some(fn_),
+ _ => unreachable!(),
+ },
+ _ => None,
+ });
+ let mut module: syn::File = parse_quote! {
+ extern "system" {
+ #(#fns_)*
+ }
+ };
+ syn::visit_mut::visit_file_mut(&mut PrependCudaPath, &mut module);
+ syn::visit_mut::visit_file_mut(&mut RemoveVisibility, &mut module);
+ syn::visit_mut::visit_file_mut(&mut ExplicitReturnType, &mut module);
+ let mut output = output.clone();
+ output.extend(path);
+ write_rust_to_file(output, &prettyplease::unparse(&module))
+}
+
fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
let mut module = module.clone();
let mut converter = ConvertIntoRustResult {
@@ -181,6 +207,39 @@ impl VisitMut for FixAbi {
}
}
+struct PrependCudaPath;
+
+impl VisitMut for PrependCudaPath {
+ fn visit_type_path_mut(&mut self, type_: &mut TypePath) {
+ if type_.path.segments.len() == 1 {
+ match &*type_.path.segments[0].ident.to_string() {
+ "usize" | "f64" | "f32" => {}
+ _ => {
+ *type_ = parse_quote! { cuda_types :: #type_ };
+ }
+ }
+ }
+ }
+}
+
+struct RemoveVisibility;
+
+impl VisitMut for RemoveVisibility {
+ fn visit_visibility_mut(&mut self, i: &mut syn::Visibility) {
+ *i = syn::Visibility::Inherited;
+ }
+}
+
+struct ExplicitReturnType;
+
+impl VisitMut for ExplicitReturnType {
+ fn visit_return_type_mut(&mut self, i: &mut syn::ReturnType) {
+ if let syn::ReturnType::Default = i {
+ *i = parse_quote! { -> {} };
+ }
+ }
+}
+
fn generate_display(
output: &PathBuf,
path: &[&str],
@@ -320,13 +379,10 @@ fn cuda_derive_display_trait_for_item<'a>(
}
let inputs = inputs
.iter()
- .map(|fn_arg| match fn_arg {
- FnArg::Typed(ref pat_type) => {
- let mut pat_type = pat_type.clone();
- pat_type.ty = prepend_cuda_path_to_type(&path_prefix, pat_type.ty);
- FnArg::Typed(pat_type)
- }
- _ => unreachable!(),
+ .map(|fn_arg| {
+ let mut fn_arg = fn_arg.clone();
+ syn::visit_mut::visit_fn_arg_mut(&mut PrependCudaPath, &mut fn_arg);
+ fn_arg
})
.collect::<Vec<_>>();
let inputs_iter = inputs.iter();
@@ -500,41 +556,6 @@ fn fn_arg_name(fn_arg: &FnArg) -> &Box<syn::Pat> {
name
}
-fn prepend_cuda_path_to_type(base_path: &Path, type_: Box<Type>) -> Box<Type> {
- match *type_ {
- Type::Path(mut type_path) => {
- type_path.path = prepend_cuda_path_to_path(base_path, type_path.path);
- Box::new(Type::Path(type_path))
- }
- Type::Ptr(mut type_ptr) => {
- type_ptr.elem = prepend_cuda_path_to_type(base_path, type_ptr.elem);
- Box::new(Type::Ptr(type_ptr))
- }
- _ => unreachable!(),
- }
-}
-
-fn prepend_cuda_path_to_path(base_path: &Path, path: Path) -> Path {
- if path.leading_colon.is_some() {
- return path;
- }
- if path.segments.len() == 1 {
- let ident = path.segments[0].ident.to_string();
- if ident.starts_with("CU")
- || ident.starts_with("cu")
- || ident.starts_with("GL")
- || ident.starts_with("EGL")
- || ident.starts_with("Vdp")
- || ident == "HGPUNV"
- {
- let mut base_path = base_path.clone();
- base_path.segments.extend(path.segments);
- return base_path;
- }
- }
- path
-}
-
fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item {
let errors = derive_state.result_variants.iter().filter_map(|const_| {
let prefix = "cudaError_enum_";