aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_bindgen
diff options
context:
space:
mode:
Diffstat (limited to 'zluda_bindgen')
-rw-r--r--zluda_bindgen/Cargo.toml1
-rw-r--r--zluda_bindgen/src/main.rs443
2 files changed, 427 insertions, 17 deletions
diff --git a/zluda_bindgen/Cargo.toml b/zluda_bindgen/Cargo.toml
index df53d49..791ad2c 100644
--- a/zluda_bindgen/Cargo.toml
+++ b/zluda_bindgen/Cargo.toml
@@ -9,3 +9,4 @@ syn = { version = "2.0", features = ["full", "visit-mut"] }
proc-macro2 = "1.0.89"
quote = "1.0"
prettyplease = "0.2.25"
+rustc-hash = "1.1.0"
diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs
index e90e07b..5e3de53 100644
--- a/zluda_bindgen/src/main.rs
+++ b/zluda_bindgen/src/main.rs
@@ -1,8 +1,11 @@
use proc_macro2::Span;
-use quote::{format_ident, quote};
-use std::{path::PathBuf, str::FromStr};
+use quote::{format_ident, quote, ToTokens};
+use rustc_hash::{FxHashMap, FxHashSet};
+use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::FromStr};
use syn::{
- parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Item, ItemUse, LitStr, UseTree,
+ parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FnArg, ForeignItem,
+ ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path, PathArguments,
+ Signature, Type, UseTree,
};
fn main() {
@@ -28,18 +31,18 @@ fn main() {
.generate()
.unwrap()
.to_string();
- generate_types(
- crate_root,
- &["..", "cuda_types", "src", "lib.rs"],
- cuda_header,
- );
+ let module: syn::File = syn::parse_str(&cuda_header).unwrap();
+ generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module);
+ generate_display(
+ &crate_root,
+ &["..", "zluda_dump", "src", "format_generated.rs"],
+ "cuda_types",
+ &module,
+ )
}
-fn generate_types(mut output: PathBuf, path: &[&str], cuda_header: String) {
- let mut module: syn::File = syn::parse_str(&cuda_header).unwrap();
- module.attrs.push(parse_quote! {
- #![allow(warnings)]
- });
+fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
+ let mut module = module.clone();
let mut converter = ConvertIntoRustResult {
type_: "CUresult",
underlying_type: "cudaError_enum",
@@ -55,15 +58,38 @@ fn generate_types(mut output: PathBuf, path: &[&str], cuda_header: String) {
Item::ForeignMod(_) => None,
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
+ Item::Struct(mut struct_) => {
+ let ident_string = struct_.ident.to_string();
+ match &*ident_string {
+ "CUdeviceptr_v2" => {
+ struct_.fields = Fields::Unnamed(parse_quote! {
+ (pub *mut ::core::ffi::c_void)
+ });
+ }
+ "CUuuid_st" => {
+ struct_.fields = Fields::Named(parse_quote! {
+ {pub bytes: [::core::ffi::c_uchar; 16usize]}
+ });
+ }
+ _ => {}
+ }
+ Some(Item::Struct(struct_))
+ }
item => Some(item),
})
.collect::<Vec<_>>();
converter.flush(&mut module.items);
syn::visit_mut::visit_file_mut(&mut FixAbi, &mut module);
- for segment in path {
- output.push(segment);
- }
- std::fs::write(output, prettyplease::unparse(&module)).unwrap();
+ let mut output = output.clone();
+ output.extend(path);
+ write_rust_to_file(output, &prettyplease::unparse(&module))
+}
+
+fn write_rust_to_file(path: impl AsRef<std::path::Path>, content: &str) {
+ let mut file = File::create(path).unwrap();
+ file.write("// Generated automatically by zluda_bindgen\n// DO NOT EDIT MANUALLY\n#![allow(warnings)]\n".as_bytes())
+ .unwrap();
+ file.write(content.as_bytes()).unwrap();
}
struct ConvertIntoRustResult {
@@ -154,3 +180,386 @@ impl VisitMut for FixAbi {
}
}
}
+
+fn generate_display(
+ output: &PathBuf,
+ path: &[&str],
+ types_crate: &'static str,
+ module: &syn::File,
+) {
+ let ignore_types = [
+ "CUarrayMapInfo_st",
+ "CUDA_RESOURCE_DESC_st",
+ "CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st",
+ "CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st",
+ "CUexecAffinityParam_st",
+ "CUstreamBatchMemOpParams_union_CUstreamMemOpWaitValueParams_st",
+ "CUstreamBatchMemOpParams_union_CUstreamMemOpWriteValueParams_st",
+ "CUuuid_st",
+ "HGPUNV",
+ "EGLint",
+ "EGLSyncKHR",
+ "EGLImageKHR",
+ "EGLStreamKHR",
+ "CUasyncNotificationInfo_st",
+ "CUgraphNodeParams_st",
+ "CUeglFrame_st",
+ "CUdevResource_st",
+ "CUlaunchAttribute_st",
+ "CUlaunchConfig_st",
+ ];
+ let ignore_functions = [
+ "cuGLGetDevices",
+ "cuGLGetDevices_v2",
+ "cuStreamSetAttribute",
+ "cuStreamSetAttribute_ptsz",
+ "cuStreamGetAttribute",
+ "cuStreamGetAttribute_ptsz",
+ "cuGraphKernelNodeGetAttribute",
+ "cuGraphKernelNodeSetAttribute",
+ ];
+ let count_selectors = [
+ ("cuCtxCreate_v3", 1, 2),
+ ("cuMemMapArrayAsync", 0, 1),
+ ("cuMemMapArrayAsync_ptsz", 0, 1),
+ ("cuStreamBatchMemOp", 2, 1),
+ ("cuStreamBatchMemOp_ptsz", 2, 1),
+ ("cuStreamBatchMemOp_v2", 2, 1),
+ ];
+ let mut derive_state = DeriveDisplayState::new(
+ &ignore_types,
+ types_crate,
+ &ignore_functions,
+ &count_selectors,
+ );
+ let mut items = module
+ .items
+ .iter()
+ .filter_map(|i| cuda_derive_display_trait_for_item(&mut derive_state, i))
+ .collect::<Vec<_>>();
+ items.push(curesult_display_trait(&derive_state));
+ let mut output = output.clone();
+ output.extend(path);
+ write_rust_to_file(
+ output,
+ &prettyplease::unparse(&syn::File {
+ shebang: None,
+ attrs: Vec::new(),
+ items,
+ }),
+ );
+}
+
+struct DeriveDisplayState<'a> {
+ types_crate: &'static str,
+ ignore_types: FxHashSet<Ident>,
+ ignore_fns: FxHashSet<Ident>,
+ enums: FxHashMap<&'a Ident, Vec<&'a Ident>>,
+ array_arguments: FxHashMap<(Ident, usize), usize>,
+ result_variants: Vec<&'a ItemConst>,
+}
+
+impl<'a> DeriveDisplayState<'a> {
+ fn new(
+ ignore_types: &[&'static str],
+ types_crate: &'static str,
+ ignore_fns: &[&'static str],
+ count_selectors: &[(&'static str, usize, usize)],
+ ) -> Self {
+ DeriveDisplayState {
+ types_crate,
+ ignore_types: ignore_types
+ .into_iter()
+ .map(|x| Ident::new(x, Span::call_site()))
+ .collect(),
+ ignore_fns: ignore_fns
+ .into_iter()
+ .map(|x| Ident::new(x, Span::call_site()))
+ .collect(),
+ array_arguments: count_selectors
+ .into_iter()
+ .map(|(name, val, count)| ((Ident::new(name, Span::call_site()), *val), *count))
+ .collect(),
+ enums: Default::default(),
+ result_variants: Vec::new(),
+ }
+ }
+
+ fn record_enum_variant(&mut self, enum_: &'a Ident, variant: &'a Ident) {
+ match self.enums.entry(enum_) {
+ hash_map::Entry::Occupied(mut entry) => {
+ entry.get_mut().push(variant);
+ }
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(vec![variant]);
+ }
+ }
+ }
+}
+
+fn cuda_derive_display_trait_for_item<'a>(
+ state: &mut DeriveDisplayState<'a>,
+ item: &'a Item,
+) -> Option<syn::Item> {
+ let path_prefix = Path::from(Ident::new(state.types_crate, Span::call_site()));
+ let path_prefix_iter = iter::repeat(&path_prefix);
+ match item {
+ Item::Const(const_) => {
+ if const_.ty.to_token_stream().to_string() == "cudaError_enum" {
+ state.result_variants.push(const_);
+ }
+ None
+ }
+ Item::ForeignMod(ItemForeignMod { items, .. }) => match items.last().unwrap() {
+ ForeignItem::Fn(ForeignItemFn {
+ sig: Signature { ident, inputs, .. },
+ ..
+ }) => {
+ if state.ignore_fns.contains(ident) {
+ return None;
+ }
+ let inputs = inputs
+ .iter()
+ .map(|fn_arg| match fn_arg {
+ FnArg::Typed(ref pat_type) => {
+ let mut pat_type = pat_type.clone();
+ pat_type.ty = prepend_cuda_path_to_type(&path_prefix, pat_type.ty);
+ FnArg::Typed(pat_type)
+ }
+ _ => unreachable!(),
+ })
+ .collect::<Vec<_>>();
+ let inputs_iter = inputs.iter();
+ let original_fn_name = ident.to_string();
+ let mut write_argument = inputs.iter().enumerate().map(|(index, fn_arg)| {
+ let name = fn_arg_name(fn_arg);
+ if let Some(length_index) = state.array_arguments.get(&(ident.clone(), index)) {
+ let length = fn_arg_name(&inputs[*length_index]);
+ quote! {
+ writer.write_all(concat!(stringify!(#name), ": ").as_bytes())?;
+ writer.write_all(b"[")?;
+ for i in 0..#length {
+ if i != 0 {
+ writer.write_all(b", ")?;
+ }
+ crate::format::CudaDisplay::write(unsafe { &*#name.add(i as usize) }, #original_fn_name, arg_idx, writer)?;
+ }
+ writer.write_all(b"]")?;
+ }
+ } else {
+ quote! {
+ writer.write_all(concat!(stringify!(#name), ": ").as_bytes())?;
+ crate::format::CudaDisplay::write(&#name, #original_fn_name, arg_idx, writer)?;
+ }
+ }
+ });
+ let fn_name = format_ident!("write_{}", ident);
+ Some(match write_argument.next() {
+ Some(first_write_argument) => parse_quote! {
+ pub fn #fn_name(writer: &mut (impl std::io::Write + ?Sized), #(#inputs_iter,)*) -> std::io::Result<()> {
+ let mut arg_idx = 0usize;
+ writer.write_all(b"(")?;
+ #first_write_argument
+ #(
+ arg_idx += 1;
+ writer.write_all(b", ")?;
+ #write_argument
+ )*
+ writer.write_all(b")")
+ }
+ },
+ None => parse_quote! {
+ pub fn #fn_name(writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ writer.write_all(b"()")
+ }
+ },
+ })
+ }
+ _ => unreachable!(),
+ },
+ Item::Impl(ref item_impl) => {
+ let enum_ = match &*item_impl.self_ty {
+ Type::Path(ref path) => &path.path.segments.last().unwrap().ident,
+ _ => unreachable!(),
+ };
+ let variant_ = match item_impl.items.last().unwrap() {
+ syn::ImplItem::Const(item_const) => &item_const.ident,
+ _ => unreachable!(),
+ };
+ state.record_enum_variant(enum_, variant_);
+ None
+ }
+ Item::Struct(item_struct) => {
+ if state.ignore_types.contains(&item_struct.ident) {
+ return None;
+ }
+ if state.enums.contains_key(&item_struct.ident) {
+ let enum_ = &item_struct.ident;
+ let enum_iter = iter::repeat(&item_struct.ident);
+ let variants = state.enums.get(&item_struct.ident).unwrap().iter();
+ Some(parse_quote! {
+ impl crate::format::CudaDisplay for #path_prefix :: #enum_ {
+ fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ match self {
+ #(& #path_prefix_iter :: #enum_iter :: #variants => writer.write_all(stringify!(#variants).as_bytes()),)*
+ _ => write!(writer, "{}", self.0)
+ }
+ }
+ }
+ })
+ } else {
+ let struct_ = &item_struct.ident;
+ let (first_field, rest_of_fields) = match item_struct.fields {
+ Fields::Named(ref fields) => {
+ let mut all_idents = fields.named.iter().filter_map(|f| {
+ let f_ident = f.ident.as_ref().unwrap();
+ let name = f_ident.to_string();
+ if name.starts_with("reserved") || name == "_unused" {
+ None
+ } else {
+ Some(f_ident)
+ }
+ });
+ let first = match all_idents.next() {
+ Some(f) => f,
+ None => return None,
+ };
+ (first, all_idents)
+ }
+ _ => return None,
+ };
+ Some(parse_quote! {
+ impl crate::format::CudaDisplay for #path_prefix :: #struct_ {
+ fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ writer.write_all(concat!("{ ", stringify!(#first_field), ": ").as_bytes())?;
+ crate::format::CudaDisplay::write(&self.#first_field, "", 0, writer)?;
+ #(
+ writer.write_all(concat!(", ", stringify!(#rest_of_fields), ": ").as_bytes())?;
+ crate::format::CudaDisplay::write(&self.#rest_of_fields, "", 0, writer)?;
+ )*
+ writer.write_all(b" }")
+ }
+ }
+ })
+ }
+ }
+ Item::Type(item_type) => {
+ if state.ignore_types.contains(&item_type.ident) {
+ return None;
+ };
+ match &*item_type.ty {
+ Type::Ptr(_) => {
+ let type_ = &item_type.ident;
+ Some(parse_quote! {
+ impl crate::format::CudaDisplay for #path_prefix :: #type_ {
+ fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ write!(writer, "{:p}", *self)
+ }
+ }
+ })
+ }
+ Type::Path(type_path) => {
+ if type_path.path.leading_colon.is_some() {
+ let option_seg = type_path.path.segments.last().unwrap();
+ if option_seg.ident == "Option" {
+ match &option_seg.arguments {
+ PathArguments::AngleBracketed(generic) => match generic.args[0] {
+ syn::GenericArgument::Type(Type::BareFn(_)) => {
+ let type_ = &item_type.ident;
+ return Some(parse_quote! {
+ impl crate::format::CudaDisplay for #path_prefix :: #type_ {
+ fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ write!(writer, "{:p}", unsafe { std::mem::transmute::<#path_prefix :: #type_, *mut ::std::ffi::c_void>(*self) })
+ }
+ }
+ });
+ }
+ _ => unreachable!(),
+ },
+ _ => unreachable!(),
+ }
+ }
+ }
+ None
+ }
+ _ => unreachable!(),
+ }
+ }
+ Item::Union(_) => None,
+ Item::Use(_) => None,
+ _ => unreachable!(),
+ }
+}
+
+fn fn_arg_name(fn_arg: &FnArg) -> &Box<syn::Pat> {
+ let name = if let FnArg::Typed(t) = fn_arg {
+ &t.pat
+ } else {
+ unreachable!()
+ };
+ name
+}
+
+fn prepend_cuda_path_to_type(base_path: &Path, type_: Box<Type>) -> Box<Type> {
+ match *type_ {
+ Type::Path(mut type_path) => {
+ type_path.path = prepend_cuda_path_to_path(base_path, type_path.path);
+ Box::new(Type::Path(type_path))
+ }
+ Type::Ptr(mut type_ptr) => {
+ type_ptr.elem = prepend_cuda_path_to_type(base_path, type_ptr.elem);
+ Box::new(Type::Ptr(type_ptr))
+ }
+ _ => unreachable!(),
+ }
+}
+
+fn prepend_cuda_path_to_path(base_path: &Path, path: Path) -> Path {
+ if path.leading_colon.is_some() {
+ return path;
+ }
+ if path.segments.len() == 1 {
+ let ident = path.segments[0].ident.to_string();
+ if ident.starts_with("CU")
+ || ident.starts_with("cu")
+ || ident.starts_with("GL")
+ || ident.starts_with("EGL")
+ || ident.starts_with("Vdp")
+ || ident == "HGPUNV"
+ {
+ let mut base_path = base_path.clone();
+ base_path.segments.extend(path.segments);
+ return base_path;
+ }
+ }
+ path
+}
+
+fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item {
+ let errors = derive_state.result_variants.iter().filter_map(|const_| {
+ let prefix = "cudaError_enum_";
+ let text = &const_.ident.to_string()[prefix.len()..];
+ if text == "CUDA_SUCCESS" {
+ return None;
+ }
+ let expr = &const_.expr;
+ Some(quote! {
+ #expr => writer.write_all(#text.as_bytes()),
+ })
+ });
+ parse_quote! {
+ impl crate::format::CudaDisplay for cuda_types::CUresult {
+ fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ match self {
+ Ok(()) => writer.write_all(b"CUDA_SUCCESS"),
+ Err(err) => {
+ match err.0.get() {
+ #(#errors)*
+ err => write!(writer, "{}", err)
+ }
+ }
+ }
+ }
+ }
+ }
+}