diff options
Diffstat (limited to 'zluda_dump/src/lib.rs')
-rw-r--r-- | zluda_dump/src/lib.rs | 114 |
1 files changed, 96 insertions, 18 deletions
diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs index 7116eee..6a7545c 100644 --- a/zluda_dump/src/lib.rs +++ b/zluda_dump/src/lib.rs @@ -42,6 +42,26 @@ macro_rules! extern_redirect { }; } +macro_rules! extern_redirect_with_post { + ( + pub fn $fn_name:ident ( $($arg_id:ident: $arg_type:ty),* $(,)? ) -> $ret_type:ty ; + $post_fn:path ; + ) => { + #[no_mangle] + pub extern "system" fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { + let original_fn = |fn_ptr| { + let typed_fn = unsafe { std::mem::transmute::<_, extern "system" fn( $( $arg_id : $arg_type),* ) -> $ret_type>(fn_ptr) }; + typed_fn($( $arg_id ),*) + }; + crate::handle_cuda_function_call_with_probes( + stringify!($fn_name), + || (), original_fn, + move |logger, state, _, cuda_result| $post_fn ( $( $arg_id ),* , logger, state, cuda_result ) + ) + } + }; +} + macro_rules! extern_redirect_with { ( pub fn $fn_name:ident ( $($arg_id:ident: $arg_type:ty),* $(,)? ) -> $ret_type:ty ; @@ -64,6 +84,7 @@ mod log; #[cfg_attr(windows, path = "os_win.rs")] #[cfg_attr(not(windows), path = "os_unix.rs")] mod os; +mod trace; pub static mut LIBCUDA_HANDLE: *mut c_void = ptr::null_mut(); pub static mut PENDING_LINKING: Option<HashMap<CUlinkState, Vec<ModuleDump>>> = None; @@ -121,7 +142,7 @@ impl<T> LateInit<T> { struct GlobalDelayedState { settings: Settings, libcuda_handle: NonNull<c_void>, - cuda_state: CUDAStateTracker, + cuda_state: trace::StateTracker, } impl GlobalDelayedState { @@ -140,10 +161,11 @@ impl GlobalDelayedState { return (LateInit::Error, fn_logger); } }; + let cuda_state = trace::StateTracker::new(&settings); let delayed_state = GlobalDelayedState { settings, libcuda_handle, - cuda_state: CUDAStateTracker::new(), + cuda_state, }; (LateInit::Success(delayed_state), fn_logger) } @@ -196,22 +218,6 @@ impl Settings { } } -// This struct contains all the information about current state of CUDA runtime -// that are relevant to us: modules, kernels, linking objects, etc. -struct CUDAStateTracker { - modules: HashMap<CUmodule, Option<ModuleDump>>, - module_counter: usize, -} - -impl CUDAStateTracker { - fn new() -> Self { - CUDAStateTracker { - modules: HashMap::new(), - module_counter: 0, - } - } -} - pub struct ModuleDump { content: Rc<String>, kernels_args: Option<HashMap<String, Vec<usize>>>, @@ -248,6 +254,50 @@ fn handle_cuda_function_call( cu_result } +fn handle_cuda_function_call_with_probes<T, PostFn>( + func: &'static str, + pre_probe: impl FnOnce() -> T, + original_cuda_fn: impl FnOnce(NonNull<c_void>) -> CUresult, + post_probe: PostFn, +) -> CUresult +where + for<'a> PostFn: FnOnce(&'a mut log::FunctionLogger, &'a mut trace::StateTracker, T, CUresult), +{ + let global_state_mutex = &*GLOBAL_STATE; + // We unwrap because there's really no sensible thing we could do, + // alternatively we could return a CUDA error, but I think it's fine to + // crash. This is a diagnostic utility, if the lock was poisoned we can't + // extract any useful trace or logging anyway + let mut global_state = &mut *global_state_mutex.lock().unwrap(); + let (mut logger, delayed_state) = match global_state.delayed_state { + LateInit::Success(ref mut delayed_state) => { + (global_state.log_factory.get_logger(func), delayed_state) + } + // There's no libcuda to load, so we might as well panic + LateInit::Error => panic!(), + LateInit::Unitialized => { + let (new_delayed_state, logger) = + GlobalDelayedState::new(func, &mut global_state.log_factory); + global_state.delayed_state = new_delayed_state; + (logger, global_state.delayed_state.as_mut().unwrap()) + } + }; + let name = std::ffi::CString::new(func).unwrap(); + let fn_ptr = + unsafe { os::get_proc_address(delayed_state.libcuda_handle.as_ptr(), name.as_c_str()) }; + let fn_ptr = NonNull::new(fn_ptr).unwrap(); + let pre_result = pre_probe(); + let cu_result = original_cuda_fn(fn_ptr); + logger.result = Some(cu_result); + post_probe( + &mut logger, + &mut delayed_state.cuda_state, + pre_result, + cu_result, + ); + cu_result +} + #[derive(Clone, Copy)] enum AllocLocation { Device, @@ -327,6 +377,34 @@ pub unsafe fn cuModuleLoadData( result } +#[allow(non_snake_case)] +pub(crate) fn cuModuleLoad_Post( + module: *mut CUmodule, + fname: *const ::std::os::raw::c_char, + fn_logger: &mut log::FunctionLogger, + state: &mut trace::StateTracker, + result: CUresult, +) { + if result != CUresult::CUDA_SUCCESS { + return; + } + state.record_new_module_file(unsafe { *module }, fname, fn_logger) +} + +#[allow(non_snake_case)] +pub(crate) fn cuModuleLoadData_Post( + module: *mut CUmodule, + raw_image: *const ::std::os::raw::c_void, + fn_logger: &mut log::FunctionLogger, + state: &mut trace::StateTracker, + result: CUresult, +) { + if result != CUresult::CUDA_SUCCESS { + return; + } + state.record_new_module(unsafe { *module }, raw_image, fn_logger) +} + unsafe fn record_module_image_raw(module: CUmodule, raw_image: *const ::std::os::raw::c_void) { if *(raw_image as *const u32) == 0x464c457f { os_log!("Unsupported ELF module image: {:?}", raw_image); |