diff options
author | Andrzej Janik <[email protected]> | 2021-02-20 21:40:19 +0100 |
---|---|---|
committer | GitHub <[email protected]> | 2021-02-20 21:40:19 +0100 |
commit | 36514bd6ebcc22fde93e1fc52b3e336da0683bfb (patch) | |
tree | 2bf3da943fddb7c4029fe9bd234ad1ceac00e8ab | |
parent | 972f612562dc534ad605bfc5a00bc908ddd8b3de (diff) | |
download | ZLUDA-36514bd6ebcc22fde93e1fc52b3e336da0683bfb.tar.gz ZLUDA-36514bd6ebcc22fde93e1fc52b3e336da0683bfb.zip |
Improve ZLUDA injection (#37)
Improve the injector and redirector so that it is no longer necessary to manually move files around when the application links against nvcuda.dll. Additionally, inject into child processes.
-rw-r--r-- | README.md | 7 | ||||
-rw-r--r-- | zluda_dump/Cargo.toml | 2 | ||||
-rw-r--r-- | zluda_inject/Cargo.toml | 2 | ||||
-rw-r--r-- | zluda_redirect/Cargo.toml | 2 | ||||
-rw-r--r-- | zluda_redirect/src/lib.rs | 730 |
5 files changed, 664 insertions, 79 deletions
@@ -52,13 +52,16 @@ Overall, ZLUDA is slower in GeekBench by roughly 2%. ### Windows You should have the most recent Intel GPU drivers installed.\ -Copy `nvcuda.dll` to the application directory (the directory where .exe file is) and launch it normally +Run your application like this: +``` +<ZLUDA_DIRECTORY>\zluda_with.exe -- <APPLICATION> <APPLICATIONS_ARGUMENTS> +``` ### Linux A very recent version of [compute-runtime](https://github.com/intel/compute-runtime) and [Level Zero loader](https://github.com/oneapi-src/level-zero/releases) is required. At the time of writing, 20.45.18403 is the oldest recommended version. Run your application like this: ``` -LD_LIBRARY_PATH=<PATH_TO_THE_DIRECTORY_WITH_ZLUDA_PROVIDED_LIBCUDA> <YOUR_APPLICATION> +LD_LIBRARY_PATH=<ZLUDA_DIRECTORY> <APPLICATION> <APPLICATIONS_ARGUMENTS> ``` ## Building diff --git a/zluda_dump/Cargo.toml b/zluda_dump/Cargo.toml index b81ef13..80c7ddc 100644 --- a/zluda_dump/Cargo.toml +++ b/zluda_dump/Cargo.toml @@ -14,7 +14,7 @@ lz4-sys = "1.9" regex = "1.4"
[target.'cfg(windows)'.dependencies]
-winapi = { version = "0.3", features = ["libloaderapi", "debugapi"] }
+winapi = { version = "0.3", features = ["libloaderapi", "debugapi", "std"] }
wchar = "0.6"
detours-sys = { path = "../detours-sys" }
diff --git a/zluda_inject/Cargo.toml b/zluda_inject/Cargo.toml index 7576e08..1181a21 100644 --- a/zluda_inject/Cargo.toml +++ b/zluda_inject/Cargo.toml @@ -9,5 +9,5 @@ name = "zluda_with" path = "src/main.rs" [target.'cfg(windows)'.dependencies] -winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "std", "synchapi", "winbase"] } +winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchapi", "winbase", "std"] } detours-sys = { path = "../detours-sys" } diff --git a/zluda_redirect/Cargo.toml b/zluda_redirect/Cargo.toml index 1b6b958..130e244 100644 --- a/zluda_redirect/Cargo.toml +++ b/zluda_redirect/Cargo.toml @@ -10,4 +10,4 @@ crate-type = ["cdylib"] [target.'cfg(windows)'.dependencies] detours-sys = { path = "../detours-sys" } wchar = "0.6" -winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "std"] }
\ No newline at end of file +winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "std"] }
\ No newline at end of file diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs index 04b2413..5de7530 100644 --- a/zluda_redirect/src/lib.rs +++ b/zluda_redirect/src/lib.rs @@ -3,20 +3,50 @@ extern crate detours_sys; extern crate winapi; -use std::{mem, ptr, slice}; +use std::{ + ffi::c_void, + mem, + os::raw::{c_int, c_uint, c_ulong}, + ptr, slice, usize, +}; use detours_sys::{ - DetourAttach, DetourDetach, DetourRestoreAfterWith, DetourTransactionBegin, - DetourTransactionCommit, DetourUpdateThread, + DetourAttach, DetourDetach, DetourRestoreAfterWith, DetourTransactionAbort, + DetourTransactionBegin, DetourTransactionCommit, DetourUpdateProcessWithDll, + DetourUpdateThread, }; use wchar::wch; -use winapi::um::processthreadsapi::GetCurrentThread; -use winapi::um::winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR}; +use winapi::{ + shared::minwindef::{BOOL, LPVOID}, + um::{ + handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, + minwinbase::LPSECURITY_ATTRIBUTES, + processthreadsapi::{ + CreateProcessA, GetCurrentProcessId, GetCurrentThread, GetCurrentThreadId, OpenThread, + ResumeThread, SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, + LPSTARTUPINFOW, + }, + tlhelp32::{ + CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32, + }, + winbase::CREATE_SUSPENDED, + winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME}, + }, +}; use winapi::{ shared::minwindef::{DWORD, FALSE, HMODULE, TRUE}, um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR}, }; use winapi::{ + shared::minwindef::{FARPROC, HINSTANCE}, + um::{ + libloaderapi::{GetModuleFileNameA, GetProcAddress}, + processthreadsapi::{CreateProcessAsUserW, CreateProcessW}, + winbase::{CreateProcessWithLogonW, CreateProcessWithTokenW}, + winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR}, + }, +}; +use winapi::{ shared::winerror::NO_ERROR, um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW}, }; @@ -27,6 +57,12 @@ const 
NVCUDA_UTF8: &'static str = "NVCUDA.DLL"; const NVCUDA_UTF16: &[u16] = wch!("NVCUDA.DLL"); static mut ZLUDA_PATH_UTF8: Vec<u8> = Vec::new(); static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None; +static mut DETACH_LOAD_LIBRARY: bool = false; +static mut NVCUDA_ORIGINAL_MODULE: HMODULE = ptr::null_mut(); +static mut CUINIT_ORIGINAL_FN: FARPROC = ptr::null_mut(); +static mut CURRENT_MODULE_FILENAME: Vec<u8> = Vec::new(); +const CUDA_ERROR_NOT_SUPPORTED: c_uint = 801; +const CUDA_ERROR_UNKNOWN: c_uint = 999; static mut LOAD_LIBRARY_A: unsafe extern "system" fn(lpLibFileName: LPCSTR) -> HMODULE = LoadLibraryA; @@ -40,6 +76,72 @@ static mut LOAD_LIBRARY_EX_A: unsafe extern "system" fn( dwFlags: DWORD, ) -> HMODULE = LoadLibraryExA; +static mut CREATE_PROCESS_A: unsafe extern "system" fn( + lpApplicationName: LPCSTR, + lpCommandLine: LPSTR, + lpProcessAttributes: LPSECURITY_ATTRIBUTES, + lpThreadAttributes: LPSECURITY_ATTRIBUTES, + bInheritHandles: BOOL, + dwCreationFlags: DWORD, + lpEnvironment: LPVOID, + lpCurrentDirectory: LPCSTR, + lpStartupInfo: LPSTARTUPINFOA, + lpProcessInformation: LPPROCESS_INFORMATION, +) -> BOOL = CreateProcessA; + +static mut CREATE_PROCESS_W: unsafe extern "system" fn( + lpApplicationName: LPCWSTR, + lpCommandLine: LPWSTR, + lpProcessAttributes: LPSECURITY_ATTRIBUTES, + lpThreadAttributes: LPSECURITY_ATTRIBUTES, + bInheritHandles: BOOL, + dwCreationFlags: DWORD, + lpEnvironment: LPVOID, + lpCurrentDirectory: LPCWSTR, + lpStartupInfo: LPSTARTUPINFOW, + lpProcessInformation: LPPROCESS_INFORMATION, +) -> BOOL = CreateProcessW; + +static mut CREATE_PROCESS_AS_USER_W: unsafe extern "system" fn( + hToken: HANDLE, + lpApplicationName: LPCWSTR, + lpCommandLine: LPWSTR, + lpProcessAttributes: LPSECURITY_ATTRIBUTES, + lpThreadAttributes: LPSECURITY_ATTRIBUTES, + bInheritHandles: BOOL, + dwCreationFlags: DWORD, + lpEnvironment: LPVOID, + lpCurrentDirectory: LPCWSTR, + lpStartupInfo: LPSTARTUPINFOW, + lpProcessInformation: LPPROCESS_INFORMATION, 
+) -> BOOL = CreateProcessAsUserW; + +static mut CREATE_PROCESS_WITH_TOKEN_W: unsafe extern "system" fn( + hToken: HANDLE, + dwLogonFlags: DWORD, + lpApplicationName: LPCWSTR, + lpCommandLine: LPWSTR, + dwCreationFlags: DWORD, + lpEnvironment: LPVOID, + lpCurrentDirectory: LPCWSTR, + lpStartupInfo: LPSTARTUPINFOW, + lpProcessInformation: LPPROCESS_INFORMATION, +) -> BOOL = CreateProcessWithTokenW; + +static mut CREATE_PROCESS_WITH_LOGON_W: unsafe extern "system" fn( + lpUsername: LPCWSTR, + lpDomain: LPCWSTR, + lpPassword: LPCWSTR, + dwLogonFlags: DWORD, + lpApplicationName: LPCWSTR, + lpCommandLine: LPWSTR, + dwCreationFlags: DWORD, + lpEnvironment: LPVOID, + lpCurrentDirectory: LPCWSTR, + lpStartupInfo: LPSTARTUPINFOW, + lpProcessInformation: LPPROCESS_INFORMATION, +) -> BOOL = CreateProcessWithLogonW; + static mut LOAD_LIBRARY_EX_W: unsafe extern "system" fn( lpLibFileName: LPCWSTR, hFile: HANDLE, @@ -100,6 +202,293 @@ unsafe extern "system" fn ZludaLoadLibraryExW( (LOAD_LIBRARY_EX_W)(nvcuda_file_name, hFile, dwFlags) } +#[allow(non_snake_case)] +unsafe extern "system" fn ZludaCreateProcessA( + lpApplicationName: LPCSTR, + lpCommandLine: LPSTR, + lpProcessAttributes: LPSECURITY_ATTRIBUTES, + lpThreadAttributes: LPSECURITY_ATTRIBUTES, + bInheritHandles: BOOL, + dwCreationFlags: DWORD, + lpEnvironment: LPVOID, + lpCurrentDirectory: LPCSTR, + lpStartupInfo: LPSTARTUPINFOA, + lpProcessInformation: LPPROCESS_INFORMATION, +) -> BOOL { + let create_proc_result = CREATE_PROCESS_A( + lpApplicationName, + lpCommandLine, + lpProcessAttributes, + lpThreadAttributes, + bInheritHandles, + dwCreationFlags | CREATE_SUSPENDED, + lpEnvironment, + lpCurrentDirectory, + lpStartupInfo, + lpProcessInformation, + ); + continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation) +} + +#[allow(non_snake_case)] +unsafe extern "system" fn ZludaCreateProcessW( + lpApplicationName: LPCWSTR, + lpCommandLine: LPWSTR, + lpProcessAttributes: LPSECURITY_ATTRIBUTES, + 
lpThreadAttributes: LPSECURITY_ATTRIBUTES, + bInheritHandles: BOOL, + dwCreationFlags: DWORD, + lpEnvironment: LPVOID, + lpCurrentDirectory: LPCWSTR, + lpStartupInfo: LPSTARTUPINFOW, + lpProcessInformation: LPPROCESS_INFORMATION, +) -> BOOL { + let create_proc_result = CREATE_PROCESS_W( + lpApplicationName, + lpCommandLine, + lpProcessAttributes, + lpThreadAttributes, + bInheritHandles, + dwCreationFlags | CREATE_SUSPENDED, + lpEnvironment, + lpCurrentDirectory, + lpStartupInfo, + lpProcessInformation, + ); + continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation) +} + +#[allow(non_snake_case)] +unsafe extern "system" fn ZludaCreateProcessAsUserW( + hToken: HANDLE, + lpApplicationName: LPCWSTR, + lpCommandLine: LPWSTR, + lpProcessAttributes: LPSECURITY_ATTRIBUTES, + lpThreadAttributes: LPSECURITY_ATTRIBUTES, + bInheritHandles: BOOL, + dwCreationFlags: DWORD, + lpEnvironment: LPVOID, + lpCurrentDirectory: LPCWSTR, + lpStartupInfo: LPSTARTUPINFOW, + lpProcessInformation: LPPROCESS_INFORMATION, +) -> BOOL { + let create_proc_result = CREATE_PROCESS_AS_USER_W( + hToken, + lpApplicationName, + lpCommandLine, + lpProcessAttributes, + lpThreadAttributes, + bInheritHandles, + dwCreationFlags | CREATE_SUSPENDED, + lpEnvironment, + lpCurrentDirectory, + lpStartupInfo, + lpProcessInformation, + ); + continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation) +} + +#[allow(non_snake_case)] +unsafe extern "system" fn ZludaCreateProcessWithLogonW( + lpUsername: LPCWSTR, + lpDomain: LPCWSTR, + lpPassword: LPCWSTR, + dwLogonFlags: DWORD, + lpApplicationName: LPCWSTR, + lpCommandLine: LPWSTR, + dwCreationFlags: DWORD, + lpEnvironment: LPVOID, + lpCurrentDirectory: LPCWSTR, + lpStartupInfo: LPSTARTUPINFOW, + lpProcessInformation: LPPROCESS_INFORMATION, +) -> BOOL { + let create_proc_result = CREATE_PROCESS_WITH_LOGON_W( + lpUsername, + lpDomain, + lpPassword, + dwLogonFlags, + lpApplicationName, + lpCommandLine, + 
dwCreationFlags | CREATE_SUSPENDED, + lpEnvironment, + lpCurrentDirectory, + lpStartupInfo, + lpProcessInformation, + ); + continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation) +} + +#[allow(non_snake_case)] +unsafe extern "system" fn ZludaCreateProcessWithTokenW( + hToken: HANDLE, + dwLogonFlags: DWORD, + lpApplicationName: LPCWSTR, + lpCommandLine: LPWSTR, + dwCreationFlags: DWORD, + lpEnvironment: LPVOID, + lpCurrentDirectory: LPCWSTR, + lpStartupInfo: LPSTARTUPINFOW, + lpProcessInformation: LPPROCESS_INFORMATION, +) -> BOOL { + let create_proc_result = CREATE_PROCESS_WITH_TOKEN_W( + hToken, + dwLogonFlags, + lpApplicationName, + lpCommandLine, + dwCreationFlags, + lpEnvironment, + lpCurrentDirectory, + lpStartupInfo, + lpProcessInformation, + ); + continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation) +} + +unsafe fn continue_create_process_hook( + create_proc_result: BOOL, + creation_flags: DWORD, + process_information: LPPROCESS_INFORMATION, +) -> BOOL { + if create_proc_result == 0 { + return 0; + } + if DetourUpdateProcessWithDll( + (*process_information).hProcess, + &mut CURRENT_MODULE_FILENAME.as_ptr() as *mut _ as *mut _, + 1, + ) == 0 + { + TerminateProcess((*process_information).hProcess, 1); + return 0; + } + if detours_sys::DetourCopyPayloadToProcess( + (*process_information).hProcess, + &PAYLOAD_GUID, + ZLUDA_PATH_UTF16.unwrap().as_ptr() as *mut _, + (ZLUDA_PATH_UTF16.unwrap().len() * mem::size_of::<u16>()) as u32, + ) == FALSE + { + TerminateProcess((*process_information).hProcess, 1); + return 0; + } + + if creation_flags & CREATE_SUSPENDED == 0 { + if ResumeThread((*process_information).hThread) == -1i32 as u32 { + TerminateProcess((*process_information).hProcess, 1); + return 0; + } + } + create_proc_result +} + +unsafe extern "C" fn cuinit_detour(flags: c_uint) -> c_uint { + let zluda_module = LoadLibraryW(ZLUDA_PATH_UTF16.unwrap().as_ptr()); + if zluda_module == ptr::null_mut() 
{ + return CUDA_ERROR_UNKNOWN; + } + let suspended_threads = suspend_all_threads_except_current(); + let suspended_threads = match suspended_threads { + Some(t) => t, + None => return CUDA_ERROR_UNKNOWN, + }; + if DetourTransactionBegin() != NO_ERROR as i32 { + resume_threads(&suspended_threads); + return CUDA_ERROR_UNKNOWN; + } + for t in suspended_threads.iter() { + if DetourUpdateThread(*t) != NO_ERROR as i32 { + DetourTransactionAbort(); + resume_threads(&suspended_threads); + return CUDA_ERROR_UNKNOWN; + } + } + if detours_sys::DetourEnumerateExports( + NVCUDA_ORIGINAL_MODULE as *mut _, + &zluda_module as *const _ as *mut _, + Some(override_nvcuda_export), + ) == FALSE + { + DetourTransactionAbort(); + resume_threads(&suspended_threads); + return CUDA_ERROR_UNKNOWN; + } + if DetourTransactionCommit() != NO_ERROR as i32 { + DetourTransactionAbort(); + resume_threads(&suspended_threads); + return CUDA_ERROR_UNKNOWN; + } + resume_threads(&suspended_threads); + let zluda_cuinit = GetProcAddress(zluda_module, b"cuInit\0".as_ptr() as *const _); + (mem::transmute::<_, unsafe extern "C" fn(c_uint) -> c_uint>(zluda_cuinit))(flags) +} + +unsafe fn suspend_all_threads_except_current() -> Option<Vec<*mut c_void>> { + let thread_snap = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0); + if thread_snap == INVALID_HANDLE_VALUE { + return None; + } + let current_thread = GetCurrentThreadId(); + let current_process = GetCurrentProcessId(); + let mut threads = Vec::new(); + let mut thread = mem::zeroed::<THREADENTRY32>(); + thread.dwSize = mem::size_of::<THREADENTRY32>() as u32; + if Thread32First(thread_snap, &mut thread) == 0 { + CloseHandle(thread_snap); + return None; + } + loop { + if thread.th32OwnerProcessID == current_process && thread.th32ThreadID != current_thread { + let thread_handle = OpenThread(THREAD_SUSPEND_RESUME, 0, thread.th32ThreadID); + if thread_handle == ptr::null_mut() { + CloseHandle(thread_snap); + resume_threads(&threads); + return None; + } + if 
SuspendThread(thread_handle) == (-1i32 as u32) { + CloseHandle(thread_snap); + resume_threads(&threads); + return None; + } + threads.push(thread_handle); + } + if Thread32Next(thread_snap, &mut thread) == 0 { + break; + } + } + CloseHandle(thread_snap); + Some(threads) +} + +unsafe fn resume_threads(threads: &[*mut c_void]) { + for t in threads { + ResumeThread(*t); + CloseHandle(*t); + } +} + +unsafe extern "C" fn override_nvcuda_export( + context_ptr: *mut c_void, + _: c_ulong, + name: LPCSTR, + mut address: *mut c_void, +) -> c_int { + let zluda_module: HMODULE = *(context_ptr as *mut HMODULE); + let mut zluda_fn = GetProcAddress(zluda_module, name); + if zluda_fn == ptr::null_mut() { + // We only support 64 bits and in all relevant calling conventions stack + // is caller-cleaned, so probably we will not crash + zluda_fn = unsupported_cuda_fn as *mut _; + } + if DetourAttach((&mut address) as *mut _, zluda_fn as *mut _) != NO_ERROR as i32 { + return FALSE; + } + TRUE +} + +unsafe extern "C" fn unsupported_cuda_fn() -> c_uint { + CUDA_ERROR_NOT_SUPPORTED +} + fn is_nvcuda_dll_utf8(lib: *const u8) -> bool { is_nvcuda_dll(lib, 0, NVCUDA_UTF8.as_bytes(), |c| { if c >= 'a' as u8 && c <= 'z' as u8 { @@ -147,94 +536,207 @@ fn is_nvcuda_dll<T: Copy + PartialEq>( #[allow(non_snake_case)] #[no_mangle] -unsafe extern "system" fn DllMain(_: *const u8, dwReason: u32, _: *const u8) -> i32 { +unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u8) -> i32 { if dwReason == DLL_PROCESS_ATTACH { if DetourRestoreAfterWith() == FALSE { return FALSE; } + if !initialize_current_module_name(instDLL) { + return FALSE; + } match get_zluda_dll_path() { Some(path) => { ZLUDA_PATH_UTF16 = Some(path); + // from_utf16_lossy(...) 
handles terminating NULL correctly ZLUDA_PATH_UTF8 = String::from_utf16_lossy(path).into_bytes(); } None => return FALSE, } - if DetourTransactionBegin() != NO_ERROR as i32 { - return FALSE; - } - if DetourUpdateThread(GetCurrentThread()) != NO_ERROR as i32 { - return FALSE; - } - if DetourAttach( - mem::transmute(&mut LOAD_LIBRARY_A), - ZludaLoadLibraryA as *mut _, - ) != NO_ERROR as i32 - { - return FALSE; - } - if DetourAttach( - mem::transmute(&mut LOAD_LIBRARY_W), - ZludaLoadLibraryW as *mut _, - ) != NO_ERROR as i32 - { - return FALSE; - } - if DetourAttach( - mem::transmute(&mut LOAD_LIBRARY_EX_A), - ZludaLoadLibraryExA as *mut _, - ) != NO_ERROR as i32 - { - return FALSE; - } - if DetourAttach( - mem::transmute(&mut LOAD_LIBRARY_EX_W), - ZludaLoadLibraryExW as *mut _, - ) != NO_ERROR as i32 - { - return FALSE; - } - if DetourTransactionCommit() != NO_ERROR as i32 { - return FALSE; + // If the application (directly or not) links to nvcuda.dll, nvcuda.dll + // will get loaded before we can act. In this case, instead of + // redirecting LoadLibrary* to load ZLUDA, we redirect cuInit to + // a cuInit implementation that will load ZLUDA and set up detours. + // We can't do it here because LoadLibrary* inside DllMain is illegal. + // We greatly prefer wholesale redirecting inside LoadLibrary*. 
+ // Hooking inside cuInit is brittle in the face of multiple + // threads (DetourUpdateThread) + match get_cuinit() { + Some((nvcuda_mod, cuinit_fn)) => attach_cuinit(nvcuda_mod, cuinit_fn), + None => attach_load_libary(), } } else if dwReason == DLL_PROCESS_DETACH { - if DetourTransactionBegin() != NO_ERROR as i32 { - return FALSE; - } - if DetourUpdateThread(GetCurrentThread()) != NO_ERROR as i32 { - return FALSE; - } - if DetourDetach( - mem::transmute(&mut LOAD_LIBRARY_A), - ZludaLoadLibraryA as *mut _, - ) != NO_ERROR as i32 - { - return FALSE; + if DETACH_LOAD_LIBRARY { + detach_load_library() + } else { + detach_cuinit() } - if DetourDetach( - mem::transmute(&mut LOAD_LIBRARY_W), - ZludaLoadLibraryW as *mut _, - ) != NO_ERROR as i32 - { - return FALSE; + } else { + TRUE + } +} + +#[must_use] +unsafe fn initialize_current_module_name(current_module: HINSTANCE) -> bool { + let mut name = vec![0; 128 as usize]; + loop { + let size = GetModuleFileNameA( + current_module, + name.as_mut_ptr() as *mut _, + name.len() as u32, + ); + if size == 0 { + return false; } - if DetourDetach( - mem::transmute(&mut LOAD_LIBRARY_EX_A), - ZludaLoadLibraryExA as *mut _, - ) != NO_ERROR as i32 - { - return FALSE; + if size < name.len() as u32 { + name.truncate(size as usize); + CURRENT_MODULE_FILENAME = name; + return true; } - if DetourDetach( - mem::transmute(&mut LOAD_LIBRARY_EX_W), - ZludaLoadLibraryExW as *mut _, - ) != NO_ERROR as i32 - { - return FALSE; + name.resize(name.len() * 2, 0); + } +} + +unsafe fn get_cuinit() -> Option<(HMODULE, FARPROC)> { + let mut module = ptr::null_mut(); + loop { + module = detours_sys::DetourEnumerateModules(module); + if module == ptr::null_mut() { + return None; } - if DetourTransactionCommit() != NO_ERROR as i32 { - return FALSE; + let cuinit_addr = GetProcAddress(module as *mut _, b"cuInit\0".as_ptr() as *const _); + if cuinit_addr != ptr::null_mut() { + return Some((module as *mut _, cuinit_addr)); } } +} + +#[must_use] +unsafe fn 
attach_cuinit(nvcuda_mod: HMODULE, mut cuinit: FARPROC) -> i32 { + if DetourTransactionBegin() != NO_ERROR as i32 { + return FALSE; + } + if !attach_create_process() { + return FALSE; + } + NVCUDA_ORIGINAL_MODULE = nvcuda_mod; + CUINIT_ORIGINAL_FN = cuinit; + if DetourAttach(mem::transmute(&mut cuinit), cuinit_detour as *mut _) != NO_ERROR as i32 { + return FALSE; + } + if DetourTransactionCommit() != NO_ERROR as i32 { + return FALSE; + } + TRUE +} + +#[must_use] +unsafe fn detach_cuinit() -> i32 { + if DetourTransactionBegin() != NO_ERROR as i32 { + return FALSE; + } + if !detach_create_process() { + return FALSE; + } + if DetourUpdateThread(GetCurrentThread()) != NO_ERROR as i32 { + return FALSE; + } + if DetourDetach( + mem::transmute(&mut CUINIT_ORIGINAL_FN), + cuinit_detour as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourTransactionCommit() != NO_ERROR as i32 { + return FALSE; + } + TRUE +} + +#[must_use] +unsafe fn attach_load_libary() -> i32 { + if DetourTransactionBegin() != NO_ERROR as i32 { + return FALSE; + } + if !attach_create_process() { + return FALSE; + } + if DetourAttach( + mem::transmute(&mut LOAD_LIBRARY_A), + ZludaLoadLibraryA as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourAttach( + mem::transmute(&mut LOAD_LIBRARY_W), + ZludaLoadLibraryW as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourAttach( + mem::transmute(&mut LOAD_LIBRARY_EX_A), + ZludaLoadLibraryExA as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourAttach( + mem::transmute(&mut LOAD_LIBRARY_EX_W), + ZludaLoadLibraryExW as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourTransactionCommit() != NO_ERROR as i32 { + return FALSE; + } + TRUE +} + +#[must_use] +unsafe fn detach_load_library() -> i32 { + if DetourTransactionBegin() != NO_ERROR as i32 { + return FALSE; + } + if !detach_create_process() { + return FALSE; + } + if DetourUpdateThread(GetCurrentThread()) != NO_ERROR as i32 { + return 
FALSE; + } + if DetourDetach( + mem::transmute(&mut LOAD_LIBRARY_A), + ZludaLoadLibraryA as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourDetach( + mem::transmute(&mut LOAD_LIBRARY_W), + ZludaLoadLibraryW as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourDetach( + mem::transmute(&mut LOAD_LIBRARY_EX_A), + ZludaLoadLibraryExA as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourDetach( + mem::transmute(&mut LOAD_LIBRARY_EX_W), + ZludaLoadLibraryExW as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourTransactionCommit() != NO_ERROR as i32 { + return FALSE; + } TRUE } @@ -258,3 +760,83 @@ fn get_zluda_dll_path() -> Option<&'static [u16]> { } None } + +#[must_use] +unsafe fn attach_create_process() -> bool { + if DetourAttach( + mem::transmute(&mut CREATE_PROCESS_A), + ZludaCreateProcessA as *mut _, + ) != NO_ERROR as i32 + { + return false; + } + if DetourAttach( + mem::transmute(&mut CREATE_PROCESS_W), + ZludaCreateProcessW as *mut _, + ) != NO_ERROR as i32 + { + return false; + } + if DetourAttach( + mem::transmute(&mut CREATE_PROCESS_AS_USER_W), + ZludaCreateProcessAsUserW as *mut _, + ) != NO_ERROR as i32 + { + return false; + } + if DetourAttach( + mem::transmute(&mut CREATE_PROCESS_WITH_LOGON_W), + ZludaCreateProcessWithLogonW as *mut _, + ) != NO_ERROR as i32 + { + return false; + } + if DetourAttach( + mem::transmute(&mut CREATE_PROCESS_WITH_TOKEN_W), + ZludaCreateProcessWithTokenW as *mut _, + ) != NO_ERROR as i32 + { + return false; + } + true +} + +#[must_use] +unsafe fn detach_create_process() -> bool { + if DetourDetach( + mem::transmute(&mut CREATE_PROCESS_A), + ZludaCreateProcessA as *mut _, + ) != NO_ERROR as i32 + { + return false; + } + if DetourDetach( + mem::transmute(&mut CREATE_PROCESS_W), + ZludaCreateProcessW as *mut _, + ) != NO_ERROR as i32 + { + return false; + } + if DetourDetach( + mem::transmute(&mut CREATE_PROCESS_AS_USER_W), + ZludaCreateProcessAsUserW as *mut _, + 
) != NO_ERROR as i32 + { + return false; + } + if DetourDetach( + mem::transmute(&mut CREATE_PROCESS_WITH_LOGON_W), + ZludaCreateProcessWithLogonW as *mut _, + ) != NO_ERROR as i32 + { + return false; + } + if DetourDetach( + mem::transmute(&mut CREATE_PROCESS_WITH_TOKEN_W), + ZludaCreateProcessWithTokenW as *mut _, + ) != NO_ERROR as i32 + { + return false; + } + true +} |