aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2024-11-14 20:07:58 +0000
committerAndrzej Janik <[email protected]>2024-11-14 20:07:58 +0000
commitc6e8c6a48a2d7463f2d55ecf537c4ab4477077f2 (patch)
treeaa8076d02745b3412b4372ba7c96b5156d260f91
parentaaf8356a545d284b43d2e89d9d08e1c8c53ad00e (diff)
downloadZLUDA-c6e8c6a48a2d7463f2d55ecf537c4ab4477077f2.tar.gz
ZLUDA-c6e8c6a48a2d7463f2d55ecf537c4ab4477077f2.zip
Fix zluda_dump
-rw-r--r--cuda_base/src/lib.rs55
-rw-r--r--zluda_dump/src/dark_api.rs6
-rw-r--r--zluda_dump/src/lib.rs19
-rw-r--r--zluda_dump/src/side_by_side.rs2
4 files changed, 63 insertions, 19 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
index c4904d9..79484fb 100644
--- a/cuda_base/src/lib.rs
+++ b/cuda_base/src/lib.rs
@@ -12,8 +12,8 @@ use syn::punctuated::Punctuated;
use syn::visit_mut::VisitMut;
use syn::{
bracketed, parse_macro_input, parse_quote, Abi, Fields, File, FnArg, ForeignItem,
- ForeignItemFn, Ident, Item, ItemForeignMod, LitStr, PatType, Path, PathArguments, PathSegment,
- ReturnType, Signature, Token, Type, TypeArray, TypePath, TypePtr, UseTree,
+ ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, LitStr, PatType, Path, PathArguments,
+ PathSegment, ReturnType, Signature, Token, Type, TypeArray, TypePath, TypePtr, UseTree,
};
const CUDA_RS: &'static str = include_str! {"cuda.rs"};
@@ -92,7 +92,7 @@ fn append_curesult(curesult_constants: Vec<syn::ItemConst>, items: &mut Vec<Item
const SUCCESS: CUresult = CUresult::Ok(());
}
} else {
- let prefix = "cudaError_enum_CUDA_ERROR_";
+ let prefix = "cudaError_enum_CUDA_";
let ident = format_ident!("{}", ident[prefix.len()..]);
quote! {
const #ident: CUresult = CUresult::Err(unsafe { ::core::num::NonZeroU32::new_unchecked(#expr) });
@@ -100,7 +100,7 @@ fn append_curesult(curesult_constants: Vec<syn::ItemConst>, items: &mut Vec<Item
}
});
items.push(parse_quote! {
- trait CUresultConsts {
+ pub trait CUresultConsts {
#(#curesult_constants)*
}
});
@@ -320,12 +320,44 @@ pub fn cuda_derive_display_trait(tokens: TokenStream) -> TokenStream {
let input = parse_macro_input!(tokens as DeriveDisplayInput);
let cuda_module = syn::parse_str::<File>(CUDA_RS).unwrap();
let mut derive_state = DeriveDisplayState::new(input);
- cuda_module
+ let mut main_body: proc_macro2::TokenStream = cuda_module
.items
.into_iter()
.filter_map(|i| cuda_derive_display_trait_for_item(&mut derive_state, i))
- .collect::<proc_macro2::TokenStream>()
- .into()
+ .collect::<proc_macro2::TokenStream>();
+ main_body.extend(curesult_display_trait(&derive_state));
+ main_body.into()
+}
+
+fn curesult_display_trait(derive_state: &DeriveDisplayState) -> proc_macro2::TokenStream {
+ let path_prefix = &derive_state.type_path;
+ let trait_ = &derive_state.trait_;
+ let errors = derive_state.result_variants.iter().filter_map(|const_| {
+ let prefix = "cudaError_enum_";
+ let text = &const_.ident.to_string()[prefix.len()..];
+ if text == "CUDA_SUCCESS" {
+ return None;
+ }
+ let expr = &const_.expr;
+ Some(quote! {
+ #expr => writer.write_all(#text.as_bytes()),
+ })
+ });
+ quote! {
+ impl #trait_ for #path_prefix :: CUresult {
+ fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ match self {
+ Ok(()) => writer.write_all(b"CUDA_SUCCESS"),
+ Err(err) => {
+ match err.get() {
+ #(#errors)*
+ err => write!(writer, "{}", err)
+ }
+ }
+ }
+ }
+ }
+ }
}
fn cuda_derive_display_trait_for_item(
@@ -337,7 +369,12 @@ fn cuda_derive_display_trait_for_item(
let trait_ = &state.trait_;
let trait_iter = iter::repeat(&state.trait_);
match item {
- Item::Const(_) => None,
+ Item::Const(const_) => {
+ if const_.ty.to_token_stream().to_string() == "cudaError_enum" {
+ state.result_variants.push(const_);
+ }
+ None
+ }
Item::ForeignMod(ItemForeignMod { mut items, .. }) => match items.pop().unwrap() {
ForeignItem::Fn(ForeignItemFn {
sig: Signature { ident, inputs, .. },
@@ -507,6 +544,7 @@ struct DeriveDisplayState {
ignore_types: FxHashSet<Ident>,
ignore_fns: FxHashSet<Ident>,
enums: FxHashMap<Ident, Vec<Ident>>,
+ result_variants: Vec<ItemConst>,
}
impl DeriveDisplayState {
@@ -517,6 +555,7 @@ impl DeriveDisplayState {
ignore_types: input.ignore_types.into_iter().collect(),
ignore_fns: input.ignore_fns.into_iter().collect(),
enums: Default::default(),
+ result_variants: Vec::new(),
}
}
diff --git a/zluda_dump/src/dark_api.rs b/zluda_dump/src/dark_api.rs
index 8b1cd79..623f96f 100644
--- a/zluda_dump/src/dark_api.rs
+++ b/zluda_dump/src/dark_api.rs
@@ -28,6 +28,7 @@ impl Hash for CUuuidWrapper {
}
}
+#[allow(improper_ctypes_definitions)]
pub(crate) struct OriginalExports {
original_get_module_from_cubin: Option<
unsafe extern "system" fn(
@@ -356,6 +357,7 @@ unsafe fn record_submodules_from_fatbin(
);
}
+#[allow(improper_ctypes_definitions)]
unsafe extern "system" fn get_module_from_cubin(
module: *mut CUmodule,
fatbinc_wrapper: *const FatbincWrapper,
@@ -388,6 +390,7 @@ unsafe extern "system" fn get_module_from_cubin(
)
}
+#[allow(improper_ctypes_definitions)]
unsafe extern "system" fn get_module_from_cubin_ext1(
module: *mut CUmodule,
fatbinc_wrapper: *const FatbincWrapper,
@@ -451,6 +454,7 @@ unsafe extern "system" fn get_module_from_cubin_ext1(
)
}
+#[allow(improper_ctypes_definitions)]
unsafe extern "system" fn get_module_from_cubin_ext2(
fatbin_header: *const FatbinHeader,
module: *mut CUmodule,
@@ -508,7 +512,7 @@ unsafe extern "system" fn get_module_from_cubin_ext2(
.original_get_module_from_cubin_ext2
.unwrap()(fatbin_header, module, ptr1, ptr2, _unknown);
fn_logger.result = Some(result);
- if result != CUresult::CUDA_SUCCESS {
+ if result.is_err() {
return result;
}
record_submodules_from_fatbin(
diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs
index f3801b8..8eb1544 100644
--- a/zluda_dump/src/lib.rs
+++ b/zluda_dump/src/lib.rs
@@ -2,10 +2,7 @@ use cuda_types::*;
use paste::paste;
use side_by_side::CudaDynamicFns;
use std::io;
-use std::{
- collections::HashMap, env, error::Error, ffi::c_void, fs, path::PathBuf, ptr::NonNull, rc::Rc,
- sync::Mutex,
-};
+use std::{collections::HashMap, env, error::Error, fs, path::PathBuf, rc::Rc, sync::Mutex};
#[macro_use]
extern crate lazy_static;
@@ -15,6 +12,7 @@ macro_rules! extern_redirect {
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path);*) => {
$(
#[no_mangle]
+ #[allow(improper_ctypes_definitions)]
pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
let original_fn = |dynamic_fns: &mut crate::side_by_side::CudaDynamicFns| {
dynamic_fns.$fn_name($( $arg_id ),*)
@@ -35,7 +33,8 @@ macro_rules! extern_redirect_with_post {
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path);*) => {
$(
#[no_mangle]
- pub extern "system" fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
+ #[allow(improper_ctypes_definitions)]
+ pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
let original_fn = |dynamic_fns: &mut crate::side_by_side::CudaDynamicFns| {
dynamic_fns.$fn_name($( $arg_id ),*)
};
@@ -323,7 +322,7 @@ where
logger.log(log::LogEntry::ErrorBox(
format!("No function {} in the underlying CUDA library", func).into(),
));
- CUresult::CUDA_ERROR_UNKNOWN
+ CUresult::ERROR_UNKNOWN
}
};
logger.result = maybe_cu_result;
@@ -357,7 +356,7 @@ pub(crate) fn cuModuleLoad_Post(
state: &mut trace::StateTracker,
result: CUresult,
) {
- if result != CUresult::CUDA_SUCCESS {
+ if result.is_err() {
return;
}
state.record_new_module_file(unsafe { *module }, fname, fn_logger)
@@ -371,7 +370,7 @@ pub(crate) fn cuModuleLoadData_Post(
state: &mut trace::StateTracker,
result: CUresult,
) {
- if result != CUresult::CUDA_SUCCESS {
+ if result.is_err() {
return;
}
state.record_new_module(unsafe { *module }, raw_image, fn_logger)
@@ -399,7 +398,7 @@ pub(crate) fn cuGetExportTable_Post(
state: &mut trace::StateTracker,
result: CUresult,
) {
- if result != CUresult::CUDA_SUCCESS {
+ if result.is_err() {
return;
}
dark_api::override_export_table(ppExportTable, pExportTableId, state)
@@ -449,7 +448,7 @@ pub(crate) fn cuModuleLoadFatBinary_Post(
_state: &mut trace::StateTracker,
result: CUresult,
) {
- if result == CUresult::CUDA_SUCCESS {
+ if result.is_ok() {
panic!()
}
}
diff --git a/zluda_dump/src/side_by_side.rs b/zluda_dump/src/side_by_side.rs
index 33954b8..076dc68 100644
--- a/zluda_dump/src/side_by_side.rs
+++ b/zluda_dump/src/side_by_side.rs
@@ -58,6 +58,8 @@ impl CudaDynamicFns {
macro_rules! emit_cuda_fn_table {
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path);*) => {
#[derive(Default)]
+ #[allow(improper_ctypes)]
+ #[allow(improper_ctypes_definitions)]
struct CudaFnTable {
$($fn_name: DynamicFn<extern $abi fn ( $($arg_id : $arg_type),* ) -> $ret_type>),*
}