From 164c172236a6fa9a84dafc0bd4887f6114478500 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 4 Feb 2022 14:14:51 +0100 Subject: Clean up ZLUDA redirection helper --- zluda_redirect/src/lib.rs | 150 +++++++++++++++++++--------------------------- 1 file changed, 61 insertions(+), 89 deletions(-) diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs index 522d705..a7d0464 100644 --- a/zluda_redirect/src/lib.rs +++ b/zluda_redirect/src/lib.rs @@ -3,11 +3,7 @@ extern crate detours_sys; extern crate winapi; -use std::{ - collections::HashMap, - ffi::{c_void, CStr}, - mem, ptr, slice, usize, -}; +use std::{ffi::c_void, mem, ptr, slice, usize}; use detours_sys::{ DetourAttach, DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin, @@ -18,6 +14,7 @@ use winapi::{ shared::minwindef::{BOOL, LPVOID}, um::{ handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, + libloaderapi::GetModuleFileNameW, minwinbase::LPSECURITY_ATTRIBUTES, processthreadsapi::{ CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread, @@ -32,15 +29,12 @@ use winapi::{ }; use winapi::{ shared::minwindef::{DWORD, FALSE, HMODULE, TRUE}, - um::{ - libloaderapi::{GetModuleHandleA, LoadLibraryExA}, - winnt::LPCSTR, - }, + um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR}, }; use winapi::{ shared::minwindef::{FARPROC, HINSTANCE}, um::{ - libloaderapi::{GetModuleFileNameA, GetProcAddress}, + libloaderapi::GetProcAddress, processthreadsapi::{CreateProcessAsUserW, CreateProcessW}, winbase::{CreateProcessWithLogonW, CreateProcessWithTokenW}, winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR}, @@ -158,15 +152,6 @@ unsafe extern "system" fn ZludaGetProcAddress_NoRedirect( hModule: HMODULE, lpProcName: LPCSTR, ) -> FARPROC { - if let Some(detour_guard) = &DETOUR_STATE { - if hModule != ptr::null_mut() && detour_guard.nvcuda_module == hModule { - let proc_name = CStr::from_ptr(lpProcName); - return match detour_guard.overriden_cuda_fns.get(proc_name) { - Some((original_fn, _)) => mem::transmute::<*mut c_void, _>(*original_fn), - None => ptr::null_mut(), - }; - } - } GetProcAddress(hModule, lpProcName) } @@ -384,8 +369,6 @@ struct DetourDetachGuard { suspended_threads: Vec<*mut c_void>, // First element is the original fn, second is the new fn overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>, - nvcuda_module: HMODULE, - overriden_cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>, } impl DetourDetachGuard { @@ -394,17 +377,11 @@ impl DetourDetachGuard { // first element in the pair, because somehow otherwise original functions // also get overriden, so for example ZludaLoadLibraryExW ends calling // itself recursively until stack overflow exception occurs - unsafe fn detour_functions<'a>( - nvcuda_module: HMODULE, - non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>, - cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>, - ) -> Option { + unsafe fn new<'a>() -> Option { let mut result = DetourDetachGuard { state: DetourUndoState::DoNothing, suspended_threads: Vec::new(), - overriden_non_cuda_fns: non_cuda_fns, - nvcuda_module, - overriden_cuda_fns: cuda_fns, + overriden_non_cuda_fns: Vec::new(), }; if DetourTransactionBegin() != NO_ERROR as i32 { return None; @@ -419,6 +396,19 @@ impl DetourDetachGuard { } } result.overriden_non_cuda_fns.extend_from_slice(&[ + ( + &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void, + ZludaLoadLibraryA as *mut c_void, + ), + (&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _), + ( + &mut LOAD_LIBRARY_EX_A as *mut _ as _, + ZludaLoadLibraryExA as _, + ), + ( + &mut LOAD_LIBRARY_EX_W as *mut _ as _, + ZludaLoadLibraryExW as _, + ), ( &mut CREATE_PROCESS_A as *mut _ as _, ZludaCreateProcessA as _, @@ -440,12 +430,7 @@ impl DetourDetachGuard { ZludaCreateProcessWithTokenW as _, ), ]); - for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied().chain( - result - .overriden_cuda_fns - .values_mut() - .map(|(original_ptr, new_ptr)| (original_ptr as *mut _, *new_ptr)), - ) { + for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied() { if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 { return None; } @@ -659,23 +644,10 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u if DetourRestoreAfterWith() == FALSE { return FALSE; } - if !initialize_current_module_name(instDLL) { + if !initialize_globals(instDLL) { return FALSE; } - match get_zluda_dlls_paths() { - Some((nvcuda_path, nvml_path)) => { - ZLUDA_PATH_UTF8 = Some(nvcuda_path); - ZLUDA_ML_PATH_UTF8 = Some(nvml_path); - ZLUDA_PATH_UTF16 = std::str::from_utf8_unchecked(nvcuda_path) - .encode_utf16() - .collect::>(); - ZLUDA_ML_PATH_UTF16 = std::str::from_utf8_unchecked(nvml_path) - .encode_utf16() - .collect::>(); - } - None => return FALSE, - } - match detour_already_loaded_nvcuda() { + match DetourDetachGuard::new() { Some(g) => { DETOUR_STATE = Some(g); TRUE @@ -692,55 +664,55 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u } } -#[must_use] -unsafe fn initialize_current_module_name(current_module: HINSTANCE) -> bool { - let mut name = vec![0; 128 as usize]; +unsafe fn initialize_globals(current_module: HINSTANCE) -> bool { + let mut module_name = vec![0; 128 as usize]; loop { - let size = GetModuleFileNameA( + let size = GetModuleFileNameW( current_module, - name.as_mut_ptr() as *mut _, - name.len() as u32, + module_name.as_mut_ptr(), + module_name.len() as u32, ); if size == 0 { return false; } - if size < name.len() as u32 { - name.truncate(size as usize); - CURRENT_MODULE_FILENAME = name; - return true; + if size < module_name.len() as u32 { + module_name.truncate(size as usize); + module_name.push(0); + CURRENT_MODULE_FILENAME = String::from_utf16_lossy(&module_name).into_bytes(); + break; } - name.resize(name.len() * 2, 0); + module_name.resize(module_name.len() * 2, 0); } + if !load_global_string( + &PAYLOAD_NVML_GUID, + &mut ZLUDA_ML_PATH_UTF8, + &mut ZLUDA_ML_PATH_UTF16, + ) { + return false; + } + if !load_global_string( + &PAYLOAD_NVCUDA_GUID, + &mut ZLUDA_PATH_UTF8, + &mut ZLUDA_PATH_UTF16, + ) { + return false; + } + true } -#[must_use] -unsafe fn detour_already_loaded_nvcuda() -> Option { - let nvcuda_mod = GetModuleHandleA(b"nvcuda\0".as_ptr() as _); - let detour_functions = vec![ - ( - &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void, - ZludaLoadLibraryA as *mut c_void, - ), - (&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _), - ( - &mut LOAD_LIBRARY_EX_A as *mut _ as _, - ZludaLoadLibraryExA as _, - ), - ( - &mut LOAD_LIBRARY_EX_W as *mut _ as _, - ZludaLoadLibraryExW as _, - ), - ]; - DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, HashMap::new()) -} - -fn get_zluda_dlls_paths() -> Option<(&'static [u8], &'static [u8])> { - match get_payload(&PAYLOAD_NVCUDA_GUID) { - None => None, - Some(nvcuda_payload) => match get_payload(&PAYLOAD_NVML_GUID) { - None => return None, - Some(nvml_payload) => return Some((nvcuda_payload, nvml_payload)), - }, +fn load_global_string( + guid: &detours_sys::GUID, + utf8_path: &mut Option<&'static [u8]>, + utf16_path: &mut Vec, +) -> bool { + if let Some(payload) = get_payload(guid) { + *utf8_path = Some(payload); + *utf16_path = unsafe { std::str::from_utf8_unchecked(payload) } + .encode_utf16() + .collect::>(); + true + } else { + false } } -- cgit v1.2.3