aboutsummaryrefslogtreecommitdiffhomepage
path: root/cuda_base/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r--cuda_base/src/lib.rs79
1 files changed, 67 insertions, 12 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
index b7ebe41..c4904d9 100644
--- a/cuda_base/src/lib.rs
+++ b/cuda_base/src/lib.rs
@@ -11,9 +11,9 @@ use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::visit_mut::VisitMut;
use syn::{
- bracketed, parse_macro_input, Abi, Fields, File, FnArg, ForeignItem, ForeignItemFn, Ident,
- Item, ItemForeignMod, LitStr, PatType, Path, PathArguments, PathSegment, ReturnType, Signature,
- Token, Type, TypeArray, TypePath, TypePtr,
+ bracketed, parse_macro_input, parse_quote, Abi, Fields, File, FnArg, ForeignItem,
+ ForeignItemFn, Ident, Item, ItemForeignMod, LitStr, PatType, Path, PathArguments, PathSegment,
+ ReturnType, Signature, Token, Type, TypeArray, TypePath, TypePtr, UseTree,
};
const CUDA_RS: &'static str = include_str! {"cuda.rs"};
@@ -26,22 +26,23 @@ const CUDA_RS: &'static str = include_str! {"cuda.rs"};
#[proc_macro]
pub fn cuda_type_declarations(_: TokenStream) -> TokenStream {
let mut cuda_module = syn::parse_str::<File>(CUDA_RS).unwrap();
+ let mut curesult_constants = Vec::new();
cuda_module.items = cuda_module
.items
.into_iter()
.filter_map(|item| match item {
Item::ForeignMod(_) => None,
Item::Struct(mut struct_) => {
- if "CUdeviceptr_v2" == struct_.ident.to_string() {
- match &mut struct_.fields {
+ let ident_string = struct_.ident.to_string();
+ match &*ident_string {
+ "CUdeviceptr_v2" => match &mut struct_.fields {
Fields::Unnamed(ref mut fields) => {
fields.unnamed[0].ty =
absolute_path_to_mut_ptr(&["std", "os", "raw", "c_void"])
}
_ => unreachable!(),
- }
- } else if "CUuuid_st" == struct_.ident.to_string() {
- match &mut struct_.fields {
+ },
+ "CUuuid_st" => match &mut struct_.fields {
Fields::Named(ref mut fields) => match fields.named[0].ty {
Type::Array(TypeArray { ref mut elem, .. }) => {
*elem = Box::new(Type::Path(TypePath {
@@ -52,17 +53,71 @@ pub fn cuda_type_declarations(_: TokenStream) -> TokenStream {
_ => unreachable!(),
},
_ => panic!(),
- }
+ },
+ _ => {}
}
Some(Item::Struct(struct_))
}
+ Item::Const(const_) => {
+ let name = const_.ident.to_string();
+ if name.starts_with("cudaError_enum_CUDA_") {
+ curesult_constants.push(const_);
+ }
+ None
+ }
+ Item::Use(use_) => {
+ if let UseTree::Path(ref path) = use_.tree {
+ if let UseTree::Rename(ref rename) = &*path.tree {
+ if rename.rename == "CUresult" {
+ return None;
+ }
+ }
+ }
+ Some(Item::Use(use_))
+ }
i => Some(i),
})
.collect::<Vec<_>>();
+ append_curesult(curesult_constants, &mut cuda_module.items);
syn::visit_mut::visit_file_mut(&mut FixAbi, &mut cuda_module);
cuda_module.into_token_stream().into()
}
+fn append_curesult(curesult_constants: Vec<syn::ItemConst>, items: &mut Vec<Item>) {
+ let curesult_constants = curesult_constants.iter().map(|const_| {
+ let ident = const_.ident.to_string();
+ let expr = &const_.expr;
+ if ident.ends_with("CUDA_SUCCESS") {
+ quote! {
+ const SUCCESS: CUresult = CUresult::Ok(());
+ }
+ } else {
+ let prefix = "cudaError_enum_CUDA_ERROR_";
+ let ident = format_ident!("{}", ident[prefix.len()..]);
+ quote! {
+ const #ident: CUresult = CUresult::Err(unsafe { ::core::num::NonZeroU32::new_unchecked(#expr) });
+ }
+ }
+ });
+ items.push(parse_quote! {
+ trait CUresultConsts {
+ #(#curesult_constants)*
+ }
+ });
+ items.push(parse_quote! {
+ impl CUresultConsts for CUresult {}
+ });
+ items.push(parse_quote! {
+ #[must_use]
+ pub type CUresult = ::core::result::Result<(), ::core::num::NonZeroU32>;
+ });
+ items.push(parse_quote! {
+ const _: fn() = || {
+ let _ = std::mem::transmute::<CUresult, u32>;
+ };
+ });
+}
+
fn segments_to_path(path: &[&'static str]) -> Path {
let mut segments = Punctuated::new();
for ident in path {
@@ -245,7 +300,7 @@ impl Parse for FnDeclInput {
input.parse::<Token![,]>()?;
let override_fns_content;
bracketed!(override_fns_content in input);
- let override_fns = override_fns_content.parse_terminated(Ident::parse)?;
+ let override_fns = override_fns_content.parse_terminated(Ident::parse, Token![,])?;
Ok(Self {
type_path,
normal_macro,
@@ -492,11 +547,11 @@ impl Parse for DeriveDisplayInput {
input.parse::<Token![,]>()?;
let ignore_types_buffer;
bracketed!(ignore_types_buffer in input);
- let ignore_types = ignore_types_buffer.parse_terminated(Ident::parse)?;
+ let ignore_types = ignore_types_buffer.parse_terminated(Ident::parse, Token![,])?;
input.parse::<Token![,]>()?;
let ignore_fns_buffer;
bracketed!(ignore_fns_buffer in input);
- let ignore_fns = ignore_fns_buffer.parse_terminated(Ident::parse)?;
+ let ignore_fns = ignore_fns_buffer.parse_terminated(Ident::parse, Token![,])?;
Ok(Self {
type_path,
trait_,