diff options
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r-- | cuda_base/src/lib.rs | 79 |
1 files changed, 67 insertions, 12 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index b7ebe41..c4904d9 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -11,9 +11,9 @@ use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::visit_mut::VisitMut; use syn::{ - bracketed, parse_macro_input, Abi, Fields, File, FnArg, ForeignItem, ForeignItemFn, Ident, - Item, ItemForeignMod, LitStr, PatType, Path, PathArguments, PathSegment, ReturnType, Signature, - Token, Type, TypeArray, TypePath, TypePtr, + bracketed, parse_macro_input, parse_quote, Abi, Fields, File, FnArg, ForeignItem, + ForeignItemFn, Ident, Item, ItemForeignMod, LitStr, PatType, Path, PathArguments, PathSegment, + ReturnType, Signature, Token, Type, TypeArray, TypePath, TypePtr, UseTree, }; const CUDA_RS: &'static str = include_str! {"cuda.rs"}; @@ -26,22 +26,23 @@ const CUDA_RS: &'static str = include_str! {"cuda.rs"}; #[proc_macro] pub fn cuda_type_declarations(_: TokenStream) -> TokenStream { let mut cuda_module = syn::parse_str::<File>(CUDA_RS).unwrap(); + let mut curesult_constants = Vec::new(); cuda_module.items = cuda_module .items .into_iter() .filter_map(|item| match item { Item::ForeignMod(_) => None, Item::Struct(mut struct_) => { - if "CUdeviceptr_v2" == struct_.ident.to_string() { - match &mut struct_.fields { + let ident_string = struct_.ident.to_string(); + match &*ident_string { + "CUdeviceptr_v2" => match &mut struct_.fields { Fields::Unnamed(ref mut fields) => { fields.unnamed[0].ty = absolute_path_to_mut_ptr(&["std", "os", "raw", "c_void"]) } _ => unreachable!(), - } - } else if "CUuuid_st" == struct_.ident.to_string() { - match &mut struct_.fields { + }, + "CUuuid_st" => match &mut struct_.fields { Fields::Named(ref mut fields) => match fields.named[0].ty { Type::Array(TypeArray { ref mut elem, .. }) => { *elem = Box::new(Type::Path(TypePath { @@ -52,17 +53,71 @@ pub fn cuda_type_declarations(_: TokenStream) -> TokenStream { _ => unreachable!(), }, _ => panic!(), - } + }, + _ => {} } Some(Item::Struct(struct_)) } + Item::Const(const_) => { + let name = const_.ident.to_string(); + if name.starts_with("cudaError_enum_CUDA_") { + curesult_constants.push(const_); + } + None + } + Item::Use(use_) => { + if let UseTree::Path(ref path) = use_.tree { + if let UseTree::Rename(ref rename) = &*path.tree { + if rename.rename == "CUresult" { + return None; + } + } + } + Some(Item::Use(use_)) + } i => Some(i), }) .collect::<Vec<_>>(); + append_curesult(curesult_constants, &mut cuda_module.items); syn::visit_mut::visit_file_mut(&mut FixAbi, &mut cuda_module); cuda_module.into_token_stream().into() } +fn append_curesult(curesult_constants: Vec<syn::ItemConst>, items: &mut Vec<Item>) { + let curesult_constants = curesult_constants.iter().map(|const_| { + let ident = const_.ident.to_string(); + let expr = &const_.expr; + if ident.ends_with("CUDA_SUCCESS") { + quote! { + const SUCCESS: CUresult = CUresult::Ok(()); + } + } else { + let prefix = "cudaError_enum_CUDA_ERROR_"; + let ident = format_ident!("{}", ident[prefix.len()..]); + quote! { + const #ident: CUresult = CUresult::Err(unsafe { ::core::num::NonZeroU32::new_unchecked(#expr) }); + } + } + }); + items.push(parse_quote! { + trait CUresultConsts { + #(#curesult_constants)* + } + }); + items.push(parse_quote! { + impl CUresultConsts for CUresult {} + }); + items.push(parse_quote! { + #[must_use] + pub type CUresult = ::core::result::Result<(), ::core::num::NonZeroU32>; + }); + items.push(parse_quote! { + const _: fn() = || { + let _ = std::mem::transmute::<CUresult, u32>; + }; + }); +} + fn segments_to_path(path: &[&'static str]) -> Path { let mut segments = Punctuated::new(); for ident in path { @@ -245,7 +300,7 @@ impl Parse for FnDeclInput { input.parse::<Token![,]>()?; let override_fns_content; bracketed!(override_fns_content in input); - let override_fns = override_fns_content.parse_terminated(Ident::parse)?; + let override_fns = override_fns_content.parse_terminated(Ident::parse, Token![,])?; Ok(Self { type_path, normal_macro, @@ -492,11 +547,11 @@ impl Parse for DeriveDisplayInput { input.parse::<Token![,]>()?; let ignore_types_buffer; bracketed!(ignore_types_buffer in input); - let ignore_types = ignore_types_buffer.parse_terminated(Ident::parse)?; + let ignore_types = ignore_types_buffer.parse_terminated(Ident::parse, Token![,])?; input.parse::<Token![,]>()?; let ignore_fns_buffer; bracketed!(ignore_fns_buffer in input); - let ignore_fns = ignore_fns_buffer.parse_terminated(Ident::parse)?; + let ignore_fns = ignore_fns_buffer.parse_terminated(Ident::parse, Token![,])?; Ok(Self { type_path, trait_, |