aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_bindgen/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda_bindgen/src/main.rs')
-rw-r--r--zluda_bindgen/src/main.rs129
1 files changed, 113 insertions, 16 deletions
diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs
index 7332254..bfa9d49 100644
--- a/zluda_bindgen/src/main.rs
+++ b/zluda_bindgen/src/main.rs
@@ -5,7 +5,7 @@ use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::
use syn::{
parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FieldsUnnamed, FnArg,
ForeignItem, ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path,
- PathArguments, Signature, Type, TypePath, UseTree,
+ PathArguments, Signature, Type, TypePath, UseTree, PathSegment
};
fn main() {
@@ -14,6 +14,11 @@ fn main() {
&crate_root,
&["..", "ext", "hip_runtime-sys", "src", "lib.rs"],
);
+ generate_ml(&crate_root);
+ generate_cuda(&crate_root);
+}
+
+fn generate_cuda(crate_root: &PathBuf) {
let cuda_header = bindgen::Builder::default()
.use_core()
.rust_target(bindgen::RustTarget::Stable_1_77)
@@ -42,16 +47,91 @@ fn main() {
.unwrap()
.to_string();
let module: syn::File = syn::parse_str(&cuda_header).unwrap();
- generate_functions(&crate_root, &["..", "cuda_base", "src", "cuda.rs"], &module);
- generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module);
+ generate_functions(
+ &crate_root,
+ "cuda",
+ &["..", "cuda_base", "src", "cuda.rs"],
+ &module,
+ );
+ generate_types_cuda(
+ &crate_root,
+ &["..", "cuda_types", "src", "cuda.rs"],
+ &module,
+ );
generate_display(
&crate_root,
&["..", "zluda_dump", "src", "format_generated.rs"],
- "cuda_types",
+ &["cuda_types", "cuda"],
&module,
)
}
+fn generate_ml(crate_root: &PathBuf) {
+ let ml_header = bindgen::Builder::default()
+ .use_core()
+ .rust_target(bindgen::RustTarget::Stable_1_77)
+ .layout_tests(false)
+ .default_enum_style(bindgen::EnumVariation::NewType {
+ is_bitfield: false,
+ is_global: false,
+ })
+ .derive_hash(true)
+ .derive_eq(true)
+ .header("/usr/local/cuda/include/nvml.h")
+ .allowlist_type("^nvml.*")
+ .allowlist_function("^nvml.*")
+ .allowlist_var("^NVML.*")
+ .must_use_type("nvmlReturn_t")
+ .constified_enum("nvmlReturn_enum")
+ .generate()
+ .unwrap()
+ .to_string();
+ let mut module: syn::File = syn::parse_str(&ml_header).unwrap();
+ let mut converter = ConvertIntoRustResult {
+ type_: "nvmlReturn_t",
+ underlying_type: "nvmlReturn_enum",
+ new_error_type: "nvmlError_t",
+ error_prefix: ("NVML_ERROR_", "ERROR_"),
+ success: ("NVML_SUCCESS", "SUCCESS"),
+ constants: Vec::new(),
+ };
+ module.items = module
+ .items
+ .into_iter()
+ .filter_map(|item| match item {
+ Item::Const(const_) => converter.get_const(const_).map(Item::Const),
+ Item::Use(use_) => converter.get_use(use_).map(Item::Use),
+ Item::Type(type_) => converter.get_type(type_).map(Item::Type),
+ item => Some(item),
+ })
+ .collect::<Vec<_>>();
+ converter.flush(&mut module.items);
+ generate_functions(
+ &crate_root,
+ "nvml",
+ &["..", "cuda_base", "src", "nvml.rs"],
+ &module,
+ );
+ generate_types(
+ &crate_root,
+ &["..", "cuda_types", "src", "nvml.rs"],
+ &module,
+ );
+}
+
+fn generate_types(crate_root: &PathBuf, path: &[&str], module: &syn::File) {
+ let non_fn = module.items.iter().filter_map(|item| match item {
+ Item::ForeignMod(_) => None,
+ _ => Some(item),
+ });
+ let module: syn::File = parse_quote! {
+ #(#non_fn)*
+ };
+ let mut output = crate_root.clone();
+ output.extend(path);
+ write_rust_to_file(output, &prettyplease::unparse(&module))
+}
+
fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
let hiprt_header = bindgen::Builder::default()
.use_core()
@@ -125,7 +205,7 @@ fn add_send_sync(items: &mut Vec<Item>, arg: &[&str]) {
}
}
-fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) {
+fn generate_functions(output: &PathBuf, submodule: &str, path: &[&str], module: &syn::File) {
let fns_ = module.items.iter().filter_map(|item| match item {
Item::ForeignMod(extern_) => match &*extern_.items {
[ForeignItem::Fn(fn_)] => Some(fn_),
@@ -138,7 +218,8 @@ fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) {
#(#fns_)*
}
};
- syn::visit_mut::visit_file_mut(&mut PrependCudaPath, &mut module);
+ let submodule = Ident::new(submodule, Span::call_site());
+ syn::visit_mut::visit_file_mut(&mut PrependCudaPath { module: submodule }, &mut module);
syn::visit_mut::visit_file_mut(&mut RemoveVisibility, &mut module);
syn::visit_mut::visit_file_mut(&mut ExplicitReturnType, &mut module);
let mut output = output.clone();
@@ -146,7 +227,7 @@ fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) {
write_rust_to_file(output, &prettyplease::unparse(&module))
}
-fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
+fn generate_types_cuda(output: &PathBuf, path: &[&str], module: &syn::File) {
let mut module = module.clone();
let mut converter = ConvertIntoRustResult {
type_: "CUresult",
@@ -314,7 +395,9 @@ impl VisitMut for FixAbi {
}
}
-struct PrependCudaPath;
+struct PrependCudaPath {
+ module: Ident,
+}
impl VisitMut for PrependCudaPath {
fn visit_type_path_mut(&mut self, type_: &mut TypePath) {
@@ -322,7 +405,8 @@ impl VisitMut for PrependCudaPath {
match &*type_.path.segments[0].ident.to_string() {
"usize" | "f64" | "f32" => {}
_ => {
- *type_ = parse_quote! { cuda_types :: #type_ };
+ let module = &self.module;
+ *type_ = parse_quote! { cuda_types :: #module :: #type_ };
}
}
}
@@ -350,7 +434,7 @@ impl VisitMut for ExplicitReturnType {
fn generate_display(
output: &PathBuf,
path: &[&str],
- types_crate: &'static str,
+ types_crate: &[&'static str],
module: &syn::File,
) {
let ignore_types = [
@@ -419,7 +503,7 @@ fn generate_display(
}
struct DeriveDisplayState<'a> {
- types_crate: &'static str,
+ types_crate: Path,
ignore_types: FxHashSet<Ident>,
ignore_fns: FxHashSet<Ident>,
enums: FxHashMap<&'a Ident, Vec<&'a Ident>>,
@@ -430,12 +514,22 @@ struct DeriveDisplayState<'a> {
impl<'a> DeriveDisplayState<'a> {
fn new(
ignore_types: &[&'static str],
- types_crate: &'static str,
+ types_crate: &[&'static str],
ignore_fns: &[&'static str],
count_selectors: &[(&'static str, usize, usize)],
) -> Self {
+ let segments = types_crate
+ .iter()
+ .map(|seg| PathSegment {
+ ident: Ident::new(seg, Span::call_site()),
+ arguments: PathArguments::None,
+ })
+ .collect::<Punctuated<_, _>>();
DeriveDisplayState {
- types_crate,
+ types_crate: Path {
+ leading_colon: None,
+ segments,
+ },
ignore_types: ignore_types
.into_iter()
.map(|x| Ident::new(x, Span::call_site()))
@@ -469,8 +563,11 @@ fn cuda_derive_display_trait_for_item<'a>(
state: &mut DeriveDisplayState<'a>,
item: &'a Item,
) -> Option<syn::Item> {
- let path_prefix = Path::from(Ident::new(state.types_crate, Span::call_site()));
+ let path_prefix = & state.types_crate;
let path_prefix_iter = iter::repeat(&path_prefix);
+ let mut prepend_path = PrependCudaPath {
+ module: Ident::new("cuda", Span::call_site()),
+ };
match item {
Item::Const(const_) => {
if const_.ty.to_token_stream().to_string() == "cudaError_enum" {
@@ -490,7 +587,7 @@ fn cuda_derive_display_trait_for_item<'a>(
.iter()
.map(|fn_arg| {
let mut fn_arg = fn_arg.clone();
- syn::visit_mut::visit_fn_arg_mut(&mut PrependCudaPath, &mut fn_arg);
+ syn::visit_mut::visit_fn_arg_mut(&mut prepend_path, &mut fn_arg);
fn_arg
})
.collect::<Vec<_>>();
@@ -686,7 +783,7 @@ fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item {
})
});
parse_quote! {
- impl crate::format::CudaDisplay for cuda_types::CUresult {
+ impl crate::format::CudaDisplay for cuda_types::cuda::CUresult {
fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
match self {
Ok(()) => writer.write_all(b"CUDA_SUCCESS"),