From 2c6d7ffb7a68514dbfa97095516681d843802012 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 5 Dec 2021 23:01:46 +0100 Subject: Fix remaining issues with detouring nvcuda --- zluda_dump/src/os_win.rs | 1 - zluda_inject/Cargo.toml | 4 +- zluda_redirect/Cargo.toml | 3 +- zluda_redirect/src/lib.rs | 133 +++++++++++++++++++++++++++++++++++++--------- 4 files changed, 112 insertions(+), 29 deletions(-) diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs index 2bfc457..d80f6e6 100644 --- a/zluda_dump/src/os_win.rs +++ b/zluda_dump/src/os_win.rs @@ -66,7 +66,6 @@ impl PlatformLibrary { if module == ptr::null_mut() { break; } - let mut size = 0; let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _); if payload != ptr::null_mut() { return Some(module as _); diff --git a/zluda_inject/Cargo.toml b/zluda_inject/Cargo.toml index cc9dc9b..73489bb 100644 --- a/zluda_inject/Cargo.toml +++ b/zluda_inject/Cargo.toml @@ -13,7 +13,7 @@ winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchap detours-sys = { path = "../detours-sys" } [dev-dependencies] -# dependency for integration tests +# all of those are used in integration tests zluda_redirect = { path = "../zluda_redirect" } -# dependency for integration tests zluda_dump = { path = "../zluda_dump" } +zluda_ml = { path = "../zluda_ml" } diff --git a/zluda_redirect/Cargo.toml b/zluda_redirect/Cargo.toml index 85c2976..97fb3e2 100644 --- a/zluda_redirect/Cargo.toml +++ b/zluda_redirect/Cargo.toml @@ -10,4 +10,5 @@ crate-type = ["cdylib"] [target.'cfg(windows)'.dependencies] detours-sys = { path = "../detours-sys" } wchar = "0.6" -winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] } \ No newline at end of file +winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] } +tempfile = "3" \ No 
newline at end of file diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs index f2d2739..657500c 100644 --- a/zluda_redirect/src/lib.rs +++ b/zluda_redirect/src/lib.rs @@ -6,33 +6,18 @@ extern crate winapi; use std::{ collections::HashMap, ffi::{c_void, CStr}, - mem, + io, mem, os::raw::c_uint, ptr, slice, usize, }; use detours_sys::{ - DetourAttach, DetourEnumerateExports, DetourRestoreAfterWith, DetourTransactionAbort, - DetourTransactionBegin, DetourTransactionCommit, DetourUpdateProcessWithDll, - DetourUpdateThread, + DetourAttach, DetourEnumerateExports, DetourGetEntryPoint, DetourRestoreAfterWith, + DetourTransactionAbort, DetourTransactionBegin, DetourTransactionCommit, + DetourUpdateProcessWithDll, DetourUpdateThread, }; +use tempfile::TempDir; use wchar::wch; -use winapi::{ - shared::minwindef::{BOOL, LPVOID}, - um::{ - handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, - minwinbase::LPSECURITY_ATTRIBUTES, - processthreadsapi::{ - CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread, - SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW, - }, - tlhelp32::{ - CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32, - }, - winbase::CREATE_SUSPENDED, - winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME}, - }, -}; use winapi::{ shared::minwindef::{DWORD, FALSE, HMODULE, TRUE}, um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR}, @@ -50,6 +35,26 @@ use winapi::{ shared::winerror::NO_ERROR, um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW}, }; +use winapi::{ + shared::{ + minwindef::{BOOL, LPVOID}, + winerror::E_UNEXPECTED, + }, + um::{ + handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, + libloaderapi::GetModuleHandleW, + minwinbase::LPSECURITY_ATTRIBUTES, + processthreadsapi::{ + CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread, + SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW, 
+ }, + tlhelp32::{ + CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32, + }, + winbase::{CopyFileW, CreateSymbolicLinkW, CREATE_SUSPENDED}, + winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME}, + }, +}; include!("payload_guid.rs"); @@ -375,6 +380,59 @@ unsafe extern "system" fn ZludaCreateProcessWithTokenW( continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation) } +static mut MAIN: unsafe extern "system" fn() -> DWORD = ZludaMain; + +// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#search-order-for-desktop-applications +// "If a DLL with the same module name is already loaded in memory, the system +// uses the loaded DLL, no matter which directory it is in. The system does not +// search for the DLL." +#[allow(non_snake_case)] +unsafe extern "system" fn ZludaMain() -> DWORD { + let temp_dir = match do_zluda_preload() { + Ok(f) => f, + Err(e) => return e.raw_os_error().unwrap_or(E_UNEXPECTED) as u32, + }; + let result = MAIN(); + drop(temp_dir); + result +} + +unsafe fn do_zluda_preload() -> std::io::Result<TempDir> { + let temp_dir = tempfile::tempdir()?; + do_single_zluda_preload(&temp_dir, ZLUDA_PATH_UTF16.unwrap().as_ptr(), NVCUDA_UTF8)?; + do_single_zluda_preload(&temp_dir, ZLUDA_ML_PATH_UTF16.unwrap().as_ptr(), NVML_UTF8)?; + Ok(temp_dir) +} + +unsafe fn do_single_zluda_preload( + temp_dir: &TempDir, + full_path: *const u16, + file_name: &'static str, +) -> io::Result<()> { + let mut temp_file_path = temp_dir.path().to_path_buf(); + temp_file_path.push(file_name); + let mut temp_file_path_utf16 = temp_file_path + .into_os_string() + .to_string_lossy() + .encode_utf16() + .collect::<Vec<_>>(); + temp_file_path_utf16.push(0); + // Probably we are not in developer mode, do a copty then + if 0 == CreateSymbolicLinkW( + temp_file_path_utf16.as_ptr(), + full_path, + 0x2, //SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE + ) { + if 0 == CopyFileW(full_path, 
temp_file_path_utf16.as_ptr(), 1) { + return Err(io::Error::last_os_error()); + } + } + if ptr::null_mut() == ZludaLoadLibraryW_NoRedirect(temp_file_path_utf16.as_ptr()) { + return Err(io::Error::last_os_error()); + } + Ok(()) +} + // This type encapsulates typical calling sequence of detours and cleanup. // We have two ways we do detours: // * If we are loaded before nvcuda.dll, we hook LoadLibrary* @@ -668,8 +726,8 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u // redirecting LoadLibrary* to load ZLUDA, we override already loaded // functions let detach_guard = match get_cuinit() { - Some((nvcuda_mod, _)) => attach_cuinit(nvcuda_mod), - None => attach_load_libary(), + Some((nvcuda_mod, _)) => detour_already_loaded_nvcuda(nvcuda_mod), + None => detour_main(), }; match detach_guard { Some(g) => { @@ -724,7 +782,7 @@ unsafe fn get_cuinit() -> Option<(HMODULE, FARPROC)> { } #[must_use] -unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> { +unsafe fn detour_already_loaded_nvcuda(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> { let zluda_module = LoadLibraryW(ZLUDA_PATH_UTF16.unwrap().as_ptr()); if zluda_module == ptr::null_mut() { return None; @@ -747,7 +805,22 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> { (original_fn_address as _, override_fn_address), ); } - DetourDetachGuard::detour_functions(nvcuda_mod, Vec::new(), override_fn_pairs) + let detour_functions = vec![ + ( + &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void, + ZludaLoadLibraryA as *mut c_void, + ), + (&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _), + ( + &mut LOAD_LIBRARY_EX_A as *mut _ as _, + ZludaLoadLibraryExA as _, + ), + ( + &mut LOAD_LIBRARY_EX_W as *mut _ as _, + ZludaLoadLibraryExW as _, + ), + ]; + DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, override_fn_pairs) } unsafe extern "system" fn cuda_unsupported() -> c_uint { @@ -776,8 +849,18 @@ unsafe extern "stdcall" fn gather_imports_impl( } #[must_use] -unsafe fn 
attach_load_libary() -> Option<DetourDetachGuard> { +unsafe fn detour_main() -> Option<DetourDetachGuard> { + let exe_handle = GetModuleHandleW(ptr::null()); + let entry_point = DetourGetEntryPoint(exe_handle as _); + if entry_point == ptr::null_mut() { + return None; + } + MAIN = mem::transmute(entry_point); let detour_functions = vec![ + ( + &mut MAIN as *mut _ as *mut *mut c_void, + ZludaMain as *mut c_void, + ), ( &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void, ZludaLoadLibraryA as *mut c_void, -- cgit v1.2.3