aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_dump
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-05-17 01:25:38 +0200
committerAndrzej Janik <[email protected]>2021-05-17 01:25:38 +0200
commit89e72e4e95858e329276b1feb080a847306e02d2 (patch)
tree763012d4cc6b6892596a71e240bb0bdea033f6c1 /zluda_dump
parentdca4c5bd21d816bb72c9a2772dd444a04717630a (diff)
downloadZLUDA-89e72e4e95858e329276b1feb080a847306e02d2.tar.gz
ZLUDA-89e72e4e95858e329276b1feb080a847306e02d2.zip
Handle even more export table functions
Diffstat (limited to 'zluda_dump')
-rw-r--r--zluda_dump/src/lib.rs43
-rw-r--r--zluda_dump/src/os_unix.rs2
-rw-r--r--zluda_dump/src/os_win.rs4
3 files changed, 36 insertions, 13 deletions
diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs
index eecd573..5b00844 100644
--- a/zluda_dump/src/lib.rs
+++ b/zluda_dump/src/lib.rs
@@ -685,7 +685,7 @@ static mut ORIGINAL_GET_MODULE_FROM_CUBIN_EXT: Option<
) -> CUresult,
> = None;
-unsafe extern "stdcall" fn report_unknown_export_table_call(
+unsafe extern "system" fn report_unknown_export_table_call(
export_table: *const CUuuid,
idx: usize,
) {
@@ -699,22 +699,27 @@ pub unsafe fn cuGetExportTable(
pExportTableId: *const CUuuid,
cont: impl FnOnce(*mut *const ::std::os::raw::c_void, *const CUuuid) -> CUresult,
) -> CUresult {
+ if ppExportTable == ptr::null_mut() || pExportTableId == ptr::null() {
+ return CUresult::CUDA_ERROR_INVALID_VALUE;
+ }
let guid = (*pExportTableId).bytes;
os_log!("Requested export table id: {{{:02X}{:02X}{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}", guid[0], guid[1], guid[2], guid[3], guid[4], guid[5], guid[6], guid[7], guid[8], guid[9], guid[10], guid[11], guid[12], guid[13], guid[14], guid[15]);
- let result = cont(ppExportTable, pExportTableId);
- if result == CUresult::CUDA_SUCCESS {
- override_export_table(ppExportTable, pExportTableId);
- }
- result
+ override_export_table(ppExportTable, pExportTableId, cont)
}
unsafe fn override_export_table(
export_table_ptr: *mut *const ::std::os::raw::c_void,
export_table_id: *const CUuuid,
-) {
+ cont: impl FnOnce(*mut *const ::std::os::raw::c_void, *const CUuuid) -> CUresult,
+) -> CUresult {
let overrides_map = OVERRIDEN_INTERFACE_VTABLES.get_or_insert_with(|| HashMap::new());
- if overrides_map.contains_key(&*export_table_id) {
- return;
+ if let Some(override_table) = overrides_map.get(&*export_table_id) {
+ *export_table_ptr = override_table.as_ptr() as *const _;
+ return CUresult::CUDA_SUCCESS;
+ }
+ let base_result = cont(export_table_ptr, export_table_id);
+ if base_result != CUresult::CUDA_SUCCESS {
+ return base_result;
}
let export_table = (*export_table_ptr) as *mut *const c_void;
let boxed_guid = Box::new(*export_table_id);
@@ -745,6 +750,7 @@ unsafe fn override_export_table(
}
*export_table_ptr = override_table.as_ptr() as *const _;
overrides_map.insert(boxed_guid, override_table);
+ CUresult::CUDA_SUCCESS
}
const TOOLS_RUNTIME_CALLBACK_HOOKS_GUID: CUuuid = CUuuid {
@@ -761,6 +767,20 @@ const CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID: CUuuid = CUuuid {
],
};
+const CTX_CREATE_BYPASS_GUID: CUuuid = CUuuid {
+ bytes: [
+ 0x0C, 0xA5, 0x0B, 0x8C, 0x10, 0x04, 0x92, 0x9A, 0x89, 0xA7, 0xD0, 0xDF, 0x10, 0xE7, 0x72,
+ 0x86,
+ ],
+};
+
+const HEAP_ACCESS_GUID: CUuuid = CUuuid {
+ bytes: [
+ 0x19, 0x5B, 0xCB, 0xF4, 0xD6, 0x7D, 0x02, 0x4A, 0xAC, 0xC5, 0x1D, 0x29, 0xCE, 0xA6, 0x31,
+ 0xAE,
+ ],
+};
+
unsafe fn get_export_override_fn(
original_fn: *const c_void,
guid: *const CUuuid,
@@ -773,7 +793,10 @@ unsafe fn get_export_override_fn(
| (CUDART_INTERFACE_GUID, 7)
| (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 0)
| (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 1)
- | (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 2) => original_fn,
+ | (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 2)
+ | (CTX_CREATE_BYPASS_GUID, 1)
+ | (HEAP_ACCESS_GUID, 1)
+ | (HEAP_ACCESS_GUID, 2) => original_fn,
(CUDART_INTERFACE_GUID, 1) => {
ORIGINAL_GET_MODULE_FROM_CUBIN = mem::transmute(original_fn);
get_module_from_cubin as *const _
diff --git a/zluda_dump/src/os_unix.rs b/zluda_dump/src/os_unix.rs
index 2cf8dad..74543dd 100644
--- a/zluda_dump/src/os_unix.rs
+++ b/zluda_dump/src/os_unix.rs
@@ -33,7 +33,7 @@ macro_rules! os_log {
#[cfg(target_arch = "x86_64")]
pub fn get_thunk(
original_fn: *const c_void,
- report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize),
+ report_fn: unsafe extern "system" fn(*const CUuuid, usize),
guid: *const CUuuid,
idx: usize,
) -> *const c_void {
diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs
index 55b69da..1617aa5 100644
--- a/zluda_dump/src/os_win.rs
+++ b/zluda_dump/src/os_win.rs
@@ -103,7 +103,7 @@ pub fn __log_impl(s: String) {
#[cfg(target_arch = "x86")]
pub fn get_thunk(
original_fn: *const c_void,
- report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize),
+ report_fn: unsafe extern "system" fn(*const CUuuid, usize),
guid: *const CUuuid,
idx: usize,
) -> *const c_void {
@@ -130,7 +130,7 @@ pub fn get_thunk(
#[cfg(target_arch = "x86_64")]
pub fn get_thunk(
original_fn: *const c_void,
- report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize),
+ report_fn: unsafe extern "system" fn(*const CUuuid, usize),
guid: *const CUuuid,
idx: usize,
) -> *const c_void {