diff options
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r-- | cuda_base/src/lib.rs | 55 |
1 files changed, 47 insertions, 8 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index c4904d9..79484fb 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -12,8 +12,8 @@ use syn::punctuated::Punctuated; use syn::visit_mut::VisitMut; use syn::{ bracketed, parse_macro_input, parse_quote, Abi, Fields, File, FnArg, ForeignItem, - ForeignItemFn, Ident, Item, ItemForeignMod, LitStr, PatType, Path, PathArguments, PathSegment, - ReturnType, Signature, Token, Type, TypeArray, TypePath, TypePtr, UseTree, + ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, LitStr, PatType, Path, PathArguments, + PathSegment, ReturnType, Signature, Token, Type, TypeArray, TypePath, TypePtr, UseTree, }; const CUDA_RS: &'static str = include_str! {"cuda.rs"}; @@ -92,7 +92,7 @@ fn append_curesult(curesult_constants: Vec<syn::ItemConst>, items: &mut Vec<Item const SUCCESS: CUresult = CUresult::Ok(()); } } else { - let prefix = "cudaError_enum_CUDA_ERROR_"; + let prefix = "cudaError_enum_CUDA_"; let ident = format_ident!("{}", ident[prefix.len()..]); quote! { const #ident: CUresult = CUresult::Err(unsafe { ::core::num::NonZeroU32::new_unchecked(#expr) }); @@ -100,7 +100,7 @@ fn append_curesult(curesult_constants: Vec<syn::ItemConst>, items: &mut Vec<Item } }); items.push(parse_quote! { - trait CUresultConsts { + pub trait CUresultConsts { #(#curesult_constants)* } }); @@ -320,12 +320,44 @@ pub fn cuda_derive_display_trait(tokens: TokenStream) -> TokenStream { let input = parse_macro_input!(tokens as DeriveDisplayInput); let cuda_module = syn::parse_str::<File>(CUDA_RS).unwrap(); let mut derive_state = DeriveDisplayState::new(input); - cuda_module + let mut main_body: proc_macro2::TokenStream = cuda_module .items .into_iter() .filter_map(|i| cuda_derive_display_trait_for_item(&mut derive_state, i)) - .collect::<proc_macro2::TokenStream>() - .into() + .collect::<proc_macro2::TokenStream>(); + main_body.extend(curesult_display_trait(&derive_state)); + main_body.into() +} + +fn curesult_display_trait(derive_state: &DeriveDisplayState) -> proc_macro2::TokenStream { + let path_prefix = &derive_state.type_path; + let trait_ = &derive_state.trait_; + let errors = derive_state.result_variants.iter().filter_map(|const_| { + let prefix = "cudaError_enum_"; + let text = &const_.ident.to_string()[prefix.len()..]; + if text == "CUDA_SUCCESS" { + return None; + } + let expr = &const_.expr; + Some(quote! { + #expr => writer.write_all(#text.as_bytes()), + }) + }); + quote! { + impl #trait_ for #path_prefix :: CUresult { + fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> { + match self { + Ok(()) => writer.write_all(b"CUDA_SUCCESS"), + Err(err) => { + match err.get() { + #(#errors)* + err => write!(writer, "{}", err) + } + } + } + } + } + } } fn cuda_derive_display_trait_for_item( @@ -337,7 +369,12 @@ fn cuda_derive_display_trait_for_item( let trait_ = &state.trait_; let trait_iter = iter::repeat(&state.trait_); match item { - Item::Const(_) => None, + Item::Const(const_) => { + if const_.ty.to_token_stream().to_string() == "cudaError_enum" { + state.result_variants.push(const_); + } + None + } Item::ForeignMod(ItemForeignMod { mut items, .. }) => match items.pop().unwrap() { ForeignItem::Fn(ForeignItemFn { sig: Signature { ident, inputs, .. }, @@ -507,6 +544,7 @@ struct DeriveDisplayState { ignore_types: FxHashSet<Ident>, ignore_fns: FxHashSet<Ident>, enums: FxHashMap<Ident, Vec<Ident>>, + result_variants: Vec<ItemConst>, } impl DeriveDisplayState { @@ -517,6 +555,7 @@ impl DeriveDisplayState { ignore_types: input.ignore_types.into_iter().collect(), ignore_fns: input.ignore_fns.into_iter().collect(), enums: Default::default(), + result_variants: Vec::new(), } } |