From 89e72e4e95858e329276b1feb080a847306e02d2 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 17 May 2021 01:25:38 +0200 Subject: Handle even more export table functions --- zluda_dump/src/lib.rs | 43 +++++++++++++++++++++++++++++++++---------- zluda_dump/src/os_unix.rs | 2 +- zluda_dump/src/os_win.rs | 4 ++-- 3 files changed, 36 insertions(+), 13 deletions(-) (limited to 'zluda_dump') diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs index eecd573..5b00844 100644 --- a/zluda_dump/src/lib.rs +++ b/zluda_dump/src/lib.rs @@ -685,7 +685,7 @@ static mut ORIGINAL_GET_MODULE_FROM_CUBIN_EXT: Option< ) -> CUresult, > = None; -unsafe extern "stdcall" fn report_unknown_export_table_call( +unsafe extern "system" fn report_unknown_export_table_call( export_table: *const CUuuid, idx: usize, ) { @@ -699,22 +699,27 @@ pub unsafe fn cuGetExportTable( pExportTableId: *const CUuuid, cont: impl FnOnce(*mut *const ::std::os::raw::c_void, *const CUuuid) -> CUresult, ) -> CUresult { + if ppExportTable == ptr::null_mut() || pExportTableId == ptr::null() { + return CUresult::CUDA_ERROR_INVALID_VALUE; + } let guid = (*pExportTableId).bytes; os_log!("Requested export table id: {{{:02X}{:02X}{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}", guid[0], guid[1], guid[2], guid[3], guid[4], guid[5], guid[6], guid[7], guid[8], guid[9], guid[10], guid[11], guid[12], guid[13], guid[14], guid[15]); - let result = cont(ppExportTable, pExportTableId); - if result == CUresult::CUDA_SUCCESS { - override_export_table(ppExportTable, pExportTableId); - } - result + override_export_table(ppExportTable, pExportTableId, cont) } unsafe fn override_export_table( export_table_ptr: *mut *const ::std::os::raw::c_void, export_table_id: *const CUuuid, -) { + cont: impl FnOnce(*mut *const ::std::os::raw::c_void, *const CUuuid) -> CUresult, +) -> CUresult { let overrides_map = OVERRIDEN_INTERFACE_VTABLES.get_or_insert_with(|| HashMap::new()); - if overrides_map.contains_key(&*export_table_id) { - return; + if let Some(override_table) = overrides_map.get(&*export_table_id) { + *export_table_ptr = override_table.as_ptr() as *const _; + return CUresult::CUDA_SUCCESS; + } + let base_result = cont(export_table_ptr, export_table_id); + if base_result != CUresult::CUDA_SUCCESS { + return base_result; } let export_table = (*export_table_ptr) as *mut *const c_void; let boxed_guid = Box::new(*export_table_id); @@ -745,6 +750,7 @@ unsafe fn override_export_table( } *export_table_ptr = override_table.as_ptr() as *const _; overrides_map.insert(boxed_guid, override_table); + CUresult::CUDA_SUCCESS } const TOOLS_RUNTIME_CALLBACK_HOOKS_GUID: CUuuid = CUuuid { @@ -761,6 +767,20 @@ const CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID: CUuuid = CUuuid { ], }; +const CTX_CREATE_BYPASS_GUID: CUuuid = CUuuid { + bytes: [ + 0x0C, 0xA5, 0x0B, 0x8C, 0x10, 0x04, 0x92, 0x9A, 0x89, 0xA7, 0xD0, 0xDF, 0x10, 0xE7, 0x72, + 0x86, + ], +}; + +const HEAP_ACCESS_GUID: CUuuid = CUuuid { + bytes: [ + 0x19, 0x5B, 0xCB, 0xF4, 0xD6, 0x7D, 0x02, 0x4A, 0xAC, 0xC5, 0x1D, 0x29, 0xCE, 0xA6, 0x31, + 0xAE, + ], +}; + unsafe fn get_export_override_fn( original_fn: *const c_void, guid: *const CUuuid, @@ -773,7 +793,10 @@ unsafe fn get_export_override_fn( | (CUDART_INTERFACE_GUID, 7) | (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 0) | (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 1) - | (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 2) => original_fn, + | (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 2) + | (CTX_CREATE_BYPASS_GUID, 1) + | (HEAP_ACCESS_GUID, 1) + | (HEAP_ACCESS_GUID, 2) => original_fn, (CUDART_INTERFACE_GUID, 1) => { ORIGINAL_GET_MODULE_FROM_CUBIN = mem::transmute(original_fn); get_module_from_cubin as *const _ diff --git a/zluda_dump/src/os_unix.rs b/zluda_dump/src/os_unix.rs index 2cf8dad..74543dd 100644 --- a/zluda_dump/src/os_unix.rs +++ b/zluda_dump/src/os_unix.rs @@ -33,7 +33,7 @@ macro_rules! os_log { #[cfg(target_arch = "x86_64")] pub fn get_thunk( original_fn: *const c_void, - report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize), + report_fn: unsafe extern "system" fn(*const CUuuid, usize), guid: *const CUuuid, idx: usize, ) -> *const c_void { diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs index 55b69da..1617aa5 100644 --- a/zluda_dump/src/os_win.rs +++ b/zluda_dump/src/os_win.rs @@ -103,7 +103,7 @@ pub fn __log_impl(s: String) { #[cfg(target_arch = "x86")] pub fn get_thunk( original_fn: *const c_void, - report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize), + report_fn: unsafe extern "system" fn(*const CUuuid, usize), guid: *const CUuuid, idx: usize, ) -> *const c_void { @@ -130,7 +130,7 @@ pub fn get_thunk( #[cfg(target_arch = "x86_64")] pub fn get_thunk( original_fn: *const c_void, - report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize), + report_fn: unsafe extern "system" fn(*const CUuuid, usize), guid: *const CUuuid, idx: usize, ) -> *const c_void { -- cgit v1.2.3