aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2022-02-04 14:14:51 +0100
committerAndrzej Janik <[email protected]>2022-02-04 14:14:51 +0100
commit164c172236a6fa9a84dafc0bd4887f6114478500 (patch)
tree9fb55d09df24e9e4a8da47f608a1b93514185229
parent2753d956df0ee3d68c3961f7b64e65df9f06bb0b (diff)
downloadZLUDA-164c172236a6fa9a84dafc0bd4887f6114478500.tar.gz
ZLUDA-164c172236a6fa9a84dafc0bd4887f6114478500.zip
Clean up ZLUDA redirection helper
-rw-r--r--zluda_redirect/src/lib.rs150
1 files changed, 61 insertions, 89 deletions
diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs
index 522d705..a7d0464 100644
--- a/zluda_redirect/src/lib.rs
+++ b/zluda_redirect/src/lib.rs
@@ -3,11 +3,7 @@
extern crate detours_sys;
extern crate winapi;
-use std::{
- collections::HashMap,
- ffi::{c_void, CStr},
- mem, ptr, slice, usize,
-};
+use std::{ffi::c_void, mem, ptr, slice, usize};
use detours_sys::{
DetourAttach, DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin,
@@ -18,6 +14,7 @@ use winapi::{
shared::minwindef::{BOOL, LPVOID},
um::{
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
+ libloaderapi::GetModuleFileNameW,
minwinbase::LPSECURITY_ATTRIBUTES,
processthreadsapi::{
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
@@ -32,15 +29,12 @@ use winapi::{
};
use winapi::{
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
- um::{
- libloaderapi::{GetModuleHandleA, LoadLibraryExA},
- winnt::LPCSTR,
- },
+ um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
};
use winapi::{
shared::minwindef::{FARPROC, HINSTANCE},
um::{
- libloaderapi::{GetModuleFileNameA, GetProcAddress},
+ libloaderapi::GetProcAddress,
processthreadsapi::{CreateProcessAsUserW, CreateProcessW},
winbase::{CreateProcessWithLogonW, CreateProcessWithTokenW},
winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR},
@@ -158,15 +152,6 @@ unsafe extern "system" fn ZludaGetProcAddress_NoRedirect(
hModule: HMODULE,
lpProcName: LPCSTR,
) -> FARPROC {
- if let Some(detour_guard) = &DETOUR_STATE {
- if hModule != ptr::null_mut() && detour_guard.nvcuda_module == hModule {
- let proc_name = CStr::from_ptr(lpProcName);
- return match detour_guard.overriden_cuda_fns.get(proc_name) {
- Some((original_fn, _)) => mem::transmute::<*mut c_void, _>(*original_fn),
- None => ptr::null_mut(),
- };
- }
- }
GetProcAddress(hModule, lpProcName)
}
@@ -384,8 +369,6 @@ struct DetourDetachGuard {
suspended_threads: Vec<*mut c_void>,
// First element is the original fn, second is the new fn
overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
- nvcuda_module: HMODULE,
- overriden_cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
}
impl DetourDetachGuard {
@@ -394,17 +377,11 @@ impl DetourDetachGuard {
// first element in the pair, because somehow otherwise original functions
// also get overriden, so for example ZludaLoadLibraryExW ends calling
// itself recursively until stack overflow exception occurs
- unsafe fn detour_functions<'a>(
- nvcuda_module: HMODULE,
- non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
- cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
- ) -> Option<Self> {
+ unsafe fn new<'a>() -> Option<Self> {
let mut result = DetourDetachGuard {
state: DetourUndoState::DoNothing,
suspended_threads: Vec::new(),
- overriden_non_cuda_fns: non_cuda_fns,
- nvcuda_module,
- overriden_cuda_fns: cuda_fns,
+ overriden_non_cuda_fns: Vec::new(),
};
if DetourTransactionBegin() != NO_ERROR as i32 {
return None;
@@ -420,6 +397,19 @@ impl DetourDetachGuard {
}
result.overriden_non_cuda_fns.extend_from_slice(&[
(
+ &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
+ ZludaLoadLibraryA as *mut c_void,
+ ),
+ (&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
+ (
+ &mut LOAD_LIBRARY_EX_A as *mut _ as _,
+ ZludaLoadLibraryExA as _,
+ ),
+ (
+ &mut LOAD_LIBRARY_EX_W as *mut _ as _,
+ ZludaLoadLibraryExW as _,
+ ),
+ (
&mut CREATE_PROCESS_A as *mut _ as _,
ZludaCreateProcessA as _,
),
@@ -440,12 +430,7 @@ impl DetourDetachGuard {
ZludaCreateProcessWithTokenW as _,
),
]);
- for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied().chain(
- result
- .overriden_cuda_fns
- .values_mut()
- .map(|(original_ptr, new_ptr)| (original_ptr as *mut _, *new_ptr)),
- ) {
+ for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied() {
if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 {
return None;
}
@@ -659,23 +644,10 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
if DetourRestoreAfterWith() == FALSE {
return FALSE;
}
- if !initialize_current_module_name(instDLL) {
+ if !initialize_globals(instDLL) {
return FALSE;
}
- match get_zluda_dlls_paths() {
- Some((nvcuda_path, nvml_path)) => {
- ZLUDA_PATH_UTF8 = Some(nvcuda_path);
- ZLUDA_ML_PATH_UTF8 = Some(nvml_path);
- ZLUDA_PATH_UTF16 = std::str::from_utf8_unchecked(nvcuda_path)
- .encode_utf16()
- .collect::<Vec<_>>();
- ZLUDA_ML_PATH_UTF16 = std::str::from_utf8_unchecked(nvml_path)
- .encode_utf16()
- .collect::<Vec<_>>();
- }
- None => return FALSE,
- }
- match detour_already_loaded_nvcuda() {
+ match DetourDetachGuard::new() {
Some(g) => {
DETOUR_STATE = Some(g);
TRUE
@@ -692,55 +664,55 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
}
}
-#[must_use]
-unsafe fn initialize_current_module_name(current_module: HINSTANCE) -> bool {
- let mut name = vec![0; 128 as usize];
+unsafe fn initialize_globals(current_module: HINSTANCE) -> bool {
+ let mut module_name = vec![0; 128 as usize];
loop {
- let size = GetModuleFileNameA(
+ let size = GetModuleFileNameW(
current_module,
- name.as_mut_ptr() as *mut _,
- name.len() as u32,
+ module_name.as_mut_ptr(),
+ module_name.len() as u32,
);
if size == 0 {
return false;
}
- if size < name.len() as u32 {
- name.truncate(size as usize);
- CURRENT_MODULE_FILENAME = name;
- return true;
+ if size < module_name.len() as u32 {
+ module_name.truncate(size as usize);
+ module_name.push(0);
+ CURRENT_MODULE_FILENAME = String::from_utf16_lossy(&module_name).into_bytes();
+ break;
}
- name.resize(name.len() * 2, 0);
+ module_name.resize(module_name.len() * 2, 0);
}
+ if !load_global_string(
+ &PAYLOAD_NVML_GUID,
+ &mut ZLUDA_ML_PATH_UTF8,
+ &mut ZLUDA_ML_PATH_UTF16,
+ ) {
+ return false;
+ }
+ if !load_global_string(
+ &PAYLOAD_NVCUDA_GUID,
+ &mut ZLUDA_PATH_UTF8,
+ &mut ZLUDA_PATH_UTF16,
+ ) {
+ return false;
+ }
+ true
}
-#[must_use]
-unsafe fn detour_already_loaded_nvcuda() -> Option<DetourDetachGuard> {
- let nvcuda_mod = GetModuleHandleA(b"nvcuda\0".as_ptr() as _);
- let detour_functions = vec![
- (
- &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
- ZludaLoadLibraryA as *mut c_void,
- ),
- (&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
- (
- &mut LOAD_LIBRARY_EX_A as *mut _ as _,
- ZludaLoadLibraryExA as _,
- ),
- (
- &mut LOAD_LIBRARY_EX_W as *mut _ as _,
- ZludaLoadLibraryExW as _,
- ),
- ];
- DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, HashMap::new())
-}
-
-fn get_zluda_dlls_paths() -> Option<(&'static [u8], &'static [u8])> {
- match get_payload(&PAYLOAD_NVCUDA_GUID) {
- None => None,
- Some(nvcuda_payload) => match get_payload(&PAYLOAD_NVML_GUID) {
- None => return None,
- Some(nvml_payload) => return Some((nvcuda_payload, nvml_payload)),
- },
+fn load_global_string(
+ guid: &detours_sys::GUID,
+ utf8_path: &mut Option<&'static [u8]>,
+ utf16_path: &mut Vec<u16>,
+) -> bool {
+ if let Some(payload) = get_payload(guid) {
+ *utf8_path = Some(payload);
+ *utf16_path = unsafe { std::str::from_utf8_unchecked(payload) }
+ .encode_utf16()
+ .collect::<Vec<_>>();
+ true
+ } else {
+ false
}
}