aboutsummaryrefslogtreecommitdiffhomepage
path: root/cuda_base/src/lib.rs
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2022-01-26 11:32:20 +0100
committerAndrzej Janik <[email protected]>2022-01-26 11:32:20 +0100
commit07aa1103aae2849116f6e8df745e222d3d57e031 (patch)
tree13861a52ef8dafc479484964daa5d600db53d86c /cuda_base/src/lib.rs
parent6f76c8b34c2132b98491d60b5120d60fb2fc80e1 (diff)
downloadZLUDA-07aa1103aae2849116f6e8df745e222d3d57e031.tar.gz
ZLUDA-07aa1103aae2849116f6e8df745e222d3d57e031.zip
Add OGL interop to cuda proc macros
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r--cuda_base/src/lib.rs89
1 files changed, 49 insertions, 40 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
index ee94e71..8b804d1 100644
--- a/cuda_base/src/lib.rs
+++ b/cuda_base/src/lib.rs
@@ -204,7 +204,11 @@ fn prepend_cuda_path_to_path(base_path: &Path, path: Path) -> Path {
}
if path.segments.len() == 1 {
let ident = path.segments[0].ident.to_string();
- if ident.starts_with("CU") || ident.starts_with("cu") {
+ if ident.starts_with("CU")
+ || ident.starts_with("cu")
+ || ident.starts_with("GL")
+ || ident == "HGPUNV"
+ {
let mut base_path = base_path.clone();
base_path.segments.extend(path.segments);
return base_path;
@@ -243,7 +247,7 @@ impl Parse for FnDeclInput {
// This trait accepts following parameters:
// * `type_path`: path to the module with type definitions (in the module tree)
// * `trait_`: name of the trait to be derived
-// * `ignore_structs`: bracketed list of types to ignore
+// * `ignore_types`: bracketed list of types to ignore
// * `ignore_fns`: bracketed list of fns to ignore
#[proc_macro]
pub fn cuda_derive_display_trait(tokens: TokenStream) -> TokenStream {
@@ -331,7 +335,7 @@ fn cuda_derive_display_trait_for_item(
}
Item::Struct(item_struct) => {
let item_struct_name = item_struct.ident.to_string();
- if state.ignore_structs.contains(&item_struct.ident) {
+ if state.ignore_types.contains(&item_struct.ident) {
return None;
}
if item_struct_name.ends_with("_enum") {
@@ -384,43 +388,48 @@ fn cuda_derive_display_trait_for_item(
})
}
}
- Item::Type(item_type) => match *(item_type.ty) {
- Type::Ptr(_) => {
- let type_ = item_type.ident;
- Some(quote! {
- impl #trait_ for #path_prefix :: #type_ {
- fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
- write!(writer, "{:p}", *self)
+ Item::Type(item_type) => {
+ if state.ignore_types.contains(&item_type.ident) {
+ return None;
+ };
+ match *(item_type.ty) {
+ Type::Ptr(_) => {
+ let type_ = item_type.ident;
+ Some(quote! {
+ impl #trait_ for #path_prefix :: #type_ {
+ fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ write!(writer, "{:p}", *self)
+ }
}
- }
- })
- }
- Type::Path(type_path) => {
- if type_path.path.leading_colon.is_some() {
- let option_seg = type_path.path.segments.last().unwrap();
- if option_seg.ident == "Option" {
- match &option_seg.arguments {
- PathArguments::AngleBracketed(generic) => match generic.args[0] {
- syn::GenericArgument::Type(Type::BareFn(_)) => {
- let type_ = &item_type.ident;
- return Some(quote! {
- impl #trait_ for #path_prefix :: #type_ {
- fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
- write!(writer, "{:p}", unsafe { std::mem::transmute::<#path_prefix :: #type_, *mut ::std::ffi::c_void>(*self) })
+ })
+ }
+ Type::Path(type_path) => {
+ if type_path.path.leading_colon.is_some() {
+ let option_seg = type_path.path.segments.last().unwrap();
+ if option_seg.ident == "Option" {
+ match &option_seg.arguments {
+ PathArguments::AngleBracketed(generic) => match generic.args[0] {
+ syn::GenericArgument::Type(Type::BareFn(_)) => {
+ let type_ = &item_type.ident;
+ return Some(quote! {
+ impl #trait_ for #path_prefix :: #type_ {
+ fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ write!(writer, "{:p}", unsafe { std::mem::transmute::<#path_prefix :: #type_, *mut ::std::ffi::c_void>(*self) })
+ }
}
- }
- });
- }
+ });
+ }
+ _ => unreachable!(),
+ },
_ => unreachable!(),
- },
- _ => unreachable!(),
+ }
}
}
+ None
}
- None
+ _ => unreachable!(),
}
- _ => unreachable!(),
- },
+ }
Item::Union(_) => None,
Item::Use(_) => None,
_ => unreachable!(),
@@ -430,7 +439,7 @@ fn cuda_derive_display_trait_for_item(
struct DeriveDisplayState {
type_path: Path,
trait_: Path,
- ignore_structs: FxHashSet<Ident>,
+ ignore_types: FxHashSet<Ident>,
ignore_fns: FxHashSet<Ident>,
enums: FxHashMap<Ident, Vec<Ident>>,
}
@@ -440,7 +449,7 @@ impl DeriveDisplayState {
DeriveDisplayState {
type_path: input.type_path,
trait_: input.trait_,
- ignore_structs: input.ignore_structs.into_iter().collect(),
+ ignore_types: input.ignore_types.into_iter().collect(),
ignore_fns: input.ignore_fns.into_iter().collect(),
enums: Default::default(),
}
@@ -461,7 +470,7 @@ impl DeriveDisplayState {
struct DeriveDisplayInput {
type_path: Path,
trait_: Path,
- ignore_structs: Punctuated<Ident, Token![,]>,
+ ignore_types: Punctuated<Ident, Token![,]>,
ignore_fns: Punctuated<Ident, Token![,]>,
}
@@ -471,9 +480,9 @@ impl Parse for DeriveDisplayInput {
input.parse::<Token![,]>()?;
let trait_ = input.parse::<Path>()?;
input.parse::<Token![,]>()?;
- let ignore_structs_buffer;
- bracketed!(ignore_structs_buffer in input);
- let ignore_structs = ignore_structs_buffer.parse_terminated(Ident::parse)?;
+ let ignore_types_buffer;
+ bracketed!(ignore_types_buffer in input);
+ let ignore_types = ignore_types_buffer.parse_terminated(Ident::parse)?;
input.parse::<Token![,]>()?;
let ignore_fns_buffer;
bracketed!(ignore_fns_buffer in input);
@@ -481,7 +490,7 @@ impl Parse for DeriveDisplayInput {
Ok(Self {
type_path,
trait_,
- ignore_structs,
+ ignore_types,
ignore_fns,
})
}