aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_dump/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda_dump/src/lib.rs')
-rw-r--r--zluda_dump/src/lib.rs114
1 files changed, 96 insertions, 18 deletions
diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs
index 7116eee..6a7545c 100644
--- a/zluda_dump/src/lib.rs
+++ b/zluda_dump/src/lib.rs
@@ -42,6 +42,26 @@ macro_rules! extern_redirect {
};
}
+macro_rules! extern_redirect_with_post {
+ (
+ pub fn $fn_name:ident ( $($arg_id:ident: $arg_type:ty),* $(,)? ) -> $ret_type:ty ;
+ $post_fn:path ;
+ ) => {
+ #[no_mangle]
+ pub extern "system" fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
+ let original_fn = |fn_ptr| {
+ let typed_fn = unsafe { std::mem::transmute::<_, extern "system" fn( $( $arg_id : $arg_type),* ) -> $ret_type>(fn_ptr) };
+ typed_fn($( $arg_id ),*)
+ };
+ crate::handle_cuda_function_call_with_probes(
+ stringify!($fn_name),
+ || (), original_fn,
+ move |logger, state, _, cuda_result| $post_fn ( $( $arg_id ),* , logger, state, cuda_result )
+ )
+ }
+ };
+}
+
macro_rules! extern_redirect_with {
(
pub fn $fn_name:ident ( $($arg_id:ident: $arg_type:ty),* $(,)? ) -> $ret_type:ty ;
@@ -64,6 +84,7 @@ mod log;
#[cfg_attr(windows, path = "os_win.rs")]
#[cfg_attr(not(windows), path = "os_unix.rs")]
mod os;
+mod trace;
pub static mut LIBCUDA_HANDLE: *mut c_void = ptr::null_mut();
pub static mut PENDING_LINKING: Option<HashMap<CUlinkState, Vec<ModuleDump>>> = None;
@@ -121,7 +142,7 @@ impl<T> LateInit<T> {
struct GlobalDelayedState {
settings: Settings,
libcuda_handle: NonNull<c_void>,
- cuda_state: CUDAStateTracker,
+ cuda_state: trace::StateTracker,
}
impl GlobalDelayedState {
@@ -140,10 +161,11 @@ impl GlobalDelayedState {
return (LateInit::Error, fn_logger);
}
};
+ let cuda_state = trace::StateTracker::new(&settings);
let delayed_state = GlobalDelayedState {
settings,
libcuda_handle,
- cuda_state: CUDAStateTracker::new(),
+ cuda_state,
};
(LateInit::Success(delayed_state), fn_logger)
}
@@ -196,22 +218,6 @@ impl Settings {
}
}
-// This struct contains all the information about current state of CUDA runtime
-// that are relevant to us: modules, kernels, linking objects, etc.
-struct CUDAStateTracker {
- modules: HashMap<CUmodule, Option<ModuleDump>>,
- module_counter: usize,
-}
-
-impl CUDAStateTracker {
- fn new() -> Self {
- CUDAStateTracker {
- modules: HashMap::new(),
- module_counter: 0,
- }
- }
-}
-
pub struct ModuleDump {
content: Rc<String>,
kernels_args: Option<HashMap<String, Vec<usize>>>,
@@ -248,6 +254,50 @@ fn handle_cuda_function_call(
cu_result
}
+fn handle_cuda_function_call_with_probes<T, PostFn>(
+ func: &'static str,
+ pre_probe: impl FnOnce() -> T,
+ original_cuda_fn: impl FnOnce(NonNull<c_void>) -> CUresult,
+ post_probe: PostFn,
+) -> CUresult
+where
+ for<'a> PostFn: FnOnce(&'a mut log::FunctionLogger, &'a mut trace::StateTracker, T, CUresult),
+{
+ let global_state_mutex = &*GLOBAL_STATE;
+ // We unwrap because there's really no sensible thing we could do,
+ // alternatively we could return a CUDA error, but I think it's fine to
+ // crash. This is a diagnostic utility, if the lock was poisoned we can't
+ // extract any useful trace or logging anyway
+ let mut global_state = &mut *global_state_mutex.lock().unwrap();
+ let (mut logger, delayed_state) = match global_state.delayed_state {
+ LateInit::Success(ref mut delayed_state) => {
+ (global_state.log_factory.get_logger(func), delayed_state)
+ }
+ // There's no libcuda to load, so we might as well panic
+ LateInit::Error => panic!(),
+ LateInit::Unitialized => {
+ let (new_delayed_state, logger) =
+ GlobalDelayedState::new(func, &mut global_state.log_factory);
+ global_state.delayed_state = new_delayed_state;
+ (logger, global_state.delayed_state.as_mut().unwrap())
+ }
+ };
+ let name = std::ffi::CString::new(func).unwrap();
+ let fn_ptr =
+ unsafe { os::get_proc_address(delayed_state.libcuda_handle.as_ptr(), name.as_c_str()) };
+ let fn_ptr = NonNull::new(fn_ptr).unwrap();
+ let pre_result = pre_probe();
+ let cu_result = original_cuda_fn(fn_ptr);
+ logger.result = Some(cu_result);
+ post_probe(
+ &mut logger,
+ &mut delayed_state.cuda_state,
+ pre_result,
+ cu_result,
+ );
+ cu_result
+}
+
#[derive(Clone, Copy)]
enum AllocLocation {
Device,
@@ -327,6 +377,34 @@ pub unsafe fn cuModuleLoadData(
result
}
+#[allow(non_snake_case)]
+pub(crate) fn cuModuleLoad_Post(
+ module: *mut CUmodule,
+ fname: *const ::std::os::raw::c_char,
+ fn_logger: &mut log::FunctionLogger,
+ state: &mut trace::StateTracker,
+ result: CUresult,
+) {
+ if result != CUresult::CUDA_SUCCESS {
+ return;
+ }
+ state.record_new_module_file(unsafe { *module }, fname, fn_logger)
+}
+
+#[allow(non_snake_case)]
+pub(crate) fn cuModuleLoadData_Post(
+ module: *mut CUmodule,
+ raw_image: *const ::std::os::raw::c_void,
+ fn_logger: &mut log::FunctionLogger,
+ state: &mut trace::StateTracker,
+ result: CUresult,
+) {
+ if result != CUresult::CUDA_SUCCESS {
+ return;
+ }
+ state.record_new_module(unsafe { *module }, raw_image, fn_logger)
+}
+
unsafe fn record_module_image_raw(module: CUmodule, raw_image: *const ::std::os::raw::c_void) {
if *(raw_image as *const u32) == 0x464c457f {
os_log!("Unsupported ELF module image: {:?}", raw_image);