diff options
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r-- | cuda_base/src/lib.rs | 89 |
1 files changed, 49 insertions, 40 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index ee94e71..8b804d1 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -204,7 +204,11 @@ fn prepend_cuda_path_to_path(base_path: &Path, path: Path) -> Path { } if path.segments.len() == 1 { let ident = path.segments[0].ident.to_string(); - if ident.starts_with("CU") || ident.starts_with("cu") { + if ident.starts_with("CU") + || ident.starts_with("cu") + || ident.starts_with("GL") + || ident == "HGPUNV" + { let mut base_path = base_path.clone(); base_path.segments.extend(path.segments); return base_path; @@ -243,7 +247,7 @@ impl Parse for FnDeclInput { // This trait accepts following parameters: // * `type_path`: path to the module with type definitions (in the module tree) // * `trait_`: name of the trait to be derived -// * `ignore_structs`: bracketed list of types to ignore +// * `ignore_types`: bracketed list of types to ignore // * `ignore_fns`: bracketed list of fns to ignore #[proc_macro] pub fn cuda_derive_display_trait(tokens: TokenStream) -> TokenStream { @@ -331,7 +335,7 @@ fn cuda_derive_display_trait_for_item( } Item::Struct(item_struct) => { let item_struct_name = item_struct.ident.to_string(); - if state.ignore_structs.contains(&item_struct.ident) { + if state.ignore_types.contains(&item_struct.ident) { return None; } if item_struct_name.ends_with("_enum") { @@ -384,43 +388,48 @@ fn cuda_derive_display_trait_for_item( }) } } - Item::Type(item_type) => match *(item_type.ty) { - Type::Ptr(_) => { - let type_ = item_type.ident; - Some(quote! { - impl #trait_ for #path_prefix :: #type_ { - fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { - write!(writer, "{:p}", *self) + Item::Type(item_type) => { + if state.ignore_types.contains(&item_type.ident) { + return None; + }; + match *(item_type.ty) { + Type::Ptr(_) => { + let type_ = item_type.ident; + Some(quote! { + impl #trait_ for #path_prefix :: #type_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + write!(writer, "{:p}", *self) + } } - } - }) - } - Type::Path(type_path) => { - if type_path.path.leading_colon.is_some() { - let option_seg = type_path.path.segments.last().unwrap(); - if option_seg.ident == "Option" { - match &option_seg.arguments { - PathArguments::AngleBracketed(generic) => match generic.args[0] { - syn::GenericArgument::Type(Type::BareFn(_)) => { - let type_ = &item_type.ident; - return Some(quote! { - impl #trait_ for #path_prefix :: #type_ { - fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { - write!(writer, "{:p}", unsafe { std::mem::transmute::<#path_prefix :: #type_, *mut ::std::ffi::c_void>(*self) }) + }) + } + Type::Path(type_path) => { + if type_path.path.leading_colon.is_some() { + let option_seg = type_path.path.segments.last().unwrap(); + if option_seg.ident == "Option" { + match &option_seg.arguments { + PathArguments::AngleBracketed(generic) => match generic.args[0] { + syn::GenericArgument::Type(Type::BareFn(_)) => { + let type_ = &item_type.ident; + return Some(quote! { + impl #trait_ for #path_prefix :: #type_ { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + write!(writer, "{:p}", unsafe { std::mem::transmute::<#path_prefix :: #type_, *mut ::std::ffi::c_void>(*self) }) + } } - } - }); - } + }); + } + _ => unreachable!(), + }, _ => unreachable!(), - }, - _ => unreachable!(), + } } } + None } - None + _ => unreachable!(), } - _ => unreachable!(), - }, + } Item::Union(_) => None, Item::Use(_) => None, _ => unreachable!(), @@ -430,7 +439,7 @@ fn cuda_derive_display_trait_for_item( struct DeriveDisplayState { type_path: Path, trait_: Path, - ignore_structs: FxHashSet<Ident>, + ignore_types: FxHashSet<Ident>, ignore_fns: FxHashSet<Ident>, enums: FxHashMap<Ident, Vec<Ident>>, } @@ -440,7 +449,7 @@ impl DeriveDisplayState { DeriveDisplayState { type_path: input.type_path, trait_: input.trait_, - ignore_structs: input.ignore_structs.into_iter().collect(), + ignore_types: input.ignore_types.into_iter().collect(), ignore_fns: input.ignore_fns.into_iter().collect(), enums: Default::default(), } @@ -461,7 +470,7 @@ impl DeriveDisplayState { struct DeriveDisplayInput { type_path: Path, trait_: Path, - ignore_structs: Punctuated<Ident, Token![,]>, + ignore_types: Punctuated<Ident, Token![,]>, ignore_fns: Punctuated<Ident, Token![,]>, } @@ -471,9 +480,9 @@ impl Parse for DeriveDisplayInput { input.parse::<Token![,]>()?; let trait_ = input.parse::<Path>()?; input.parse::<Token![,]>()?; - let ignore_structs_buffer; - bracketed!(ignore_structs_buffer in input); - let ignore_structs = ignore_structs_buffer.parse_terminated(Ident::parse)?; + let ignore_types_buffer; + bracketed!(ignore_types_buffer in input); + let ignore_types = ignore_types_buffer.parse_terminated(Ident::parse)?; input.parse::<Token![,]>()?; let ignore_fns_buffer; bracketed!(ignore_fns_buffer in input); @@ -481,7 +490,7 @@ impl Parse for DeriveDisplayInput { Ok(Self { type_path, trait_, - ignore_structs, + ignore_types, ignore_fns, }) } |