aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-02-20 21:40:19 +0100
committerGitHub <[email protected]>2021-02-20 21:40:19 +0100
commit36514bd6ebcc22fde93e1fc52b3e336da0683bfb (patch)
tree2bf3da943fddb7c4029fe9bd234ad1ceac00e8ab
parent972f612562dc534ad605bfc5a00bc908ddd8b3de (diff)
downloadZLUDA-36514bd6ebcc22fde93e1fc52b3e336da0683bfb.tar.gz
ZLUDA-36514bd6ebcc22fde93e1fc52b3e336da0683bfb.zip
Improve ZLUDA injection (#37)
Improve injector&redirector so it's no longer required to manually mess with files if the application links nvcuda.dll. Additionally inject into child processes
-rw-r--r--README.md7
-rw-r--r--zluda_dump/Cargo.toml2
-rw-r--r--zluda_inject/Cargo.toml2
-rw-r--r--zluda_redirect/Cargo.toml2
-rw-r--r--zluda_redirect/src/lib.rs730
5 files changed, 664 insertions, 79 deletions
diff --git a/README.md b/README.md
index a6ea202..fe1674f 100644
--- a/README.md
+++ b/README.md
@@ -52,13 +52,16 @@ Overall, ZLUDA is slower in GeekBench by roughly 2%.
### Windows
You should have the most recent Intel GPU drivers installed.\
-Copy `nvcuda.dll` to the application directory (the directory where .exe file is) and launch it normally
+Run your application like this:
+```
+<ZLUDA_DIRECTORY>\zluda_with.exe -- <APPLICATION> <APPLICATIONS_ARGUMENTS>
+```
### Linux
A very recent version of [compute-runtime](https://github.com/intel/compute-runtime) and [Level Zero loader](https://github.com/oneapi-src/level-zero/releases) are required. At the time of writing, 20.45.18403 is the oldest recommended version.
Run your application like this:
```
-LD_LIBRARY_PATH=<PATH_TO_THE_DIRECTORY_WITH_ZLUDA_PROVIDED_LIBCUDA> <YOUR_APPLICATION>
+LD_LIBRARY_PATH=<ZLUDA_DIRECTORY> <APPLICATION> <APPLICATIONS_ARGUMENTS>
```
## Building
diff --git a/zluda_dump/Cargo.toml b/zluda_dump/Cargo.toml
index b81ef13..80c7ddc 100644
--- a/zluda_dump/Cargo.toml
+++ b/zluda_dump/Cargo.toml
@@ -14,7 +14,7 @@ lz4-sys = "1.9"
regex = "1.4"
[target.'cfg(windows)'.dependencies]
-winapi = { version = "0.3", features = ["libloaderapi", "debugapi"] }
+winapi = { version = "0.3", features = ["libloaderapi", "debugapi", "std"] }
wchar = "0.6"
detours-sys = { path = "../detours-sys" }
diff --git a/zluda_inject/Cargo.toml b/zluda_inject/Cargo.toml
index 7576e08..1181a21 100644
--- a/zluda_inject/Cargo.toml
+++ b/zluda_inject/Cargo.toml
@@ -9,5 +9,5 @@ name = "zluda_with"
path = "src/main.rs"
[target.'cfg(windows)'.dependencies]
-winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "std", "synchapi", "winbase"] }
+winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchapi", "winbase", "std"] }
detours-sys = { path = "../detours-sys" }
diff --git a/zluda_redirect/Cargo.toml b/zluda_redirect/Cargo.toml
index 1b6b958..130e244 100644
--- a/zluda_redirect/Cargo.toml
+++ b/zluda_redirect/Cargo.toml
@@ -10,4 +10,4 @@ crate-type = ["cdylib"]
[target.'cfg(windows)'.dependencies]
detours-sys = { path = "../detours-sys" }
wchar = "0.6"
-winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "std"] } \ No newline at end of file
+winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "std"] } \ No newline at end of file
diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs
index 04b2413..5de7530 100644
--- a/zluda_redirect/src/lib.rs
+++ b/zluda_redirect/src/lib.rs
@@ -3,20 +3,50 @@
extern crate detours_sys;
extern crate winapi;
-use std::{mem, ptr, slice};
+use std::{
+ ffi::c_void,
+ mem,
+ os::raw::{c_int, c_uint, c_ulong},
+ ptr, slice, usize,
+};
use detours_sys::{
- DetourAttach, DetourDetach, DetourRestoreAfterWith, DetourTransactionBegin,
- DetourTransactionCommit, DetourUpdateThread,
+ DetourAttach, DetourDetach, DetourRestoreAfterWith, DetourTransactionAbort,
+ DetourTransactionBegin, DetourTransactionCommit, DetourUpdateProcessWithDll,
+ DetourUpdateThread,
};
use wchar::wch;
-use winapi::um::processthreadsapi::GetCurrentThread;
-use winapi::um::winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR};
+use winapi::{
+ shared::minwindef::{BOOL, LPVOID},
+ um::{
+ handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
+ minwinbase::LPSECURITY_ATTRIBUTES,
+ processthreadsapi::{
+ CreateProcessA, GetCurrentProcessId, GetCurrentThread, GetCurrentThreadId, OpenThread,
+ ResumeThread, SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA,
+ LPSTARTUPINFOW,
+ },
+ tlhelp32::{
+ CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32,
+ },
+ winbase::CREATE_SUSPENDED,
+ winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME},
+ },
+};
use winapi::{
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
};
use winapi::{
+ shared::minwindef::{FARPROC, HINSTANCE},
+ um::{
+ libloaderapi::{GetModuleFileNameA, GetProcAddress},
+ processthreadsapi::{CreateProcessAsUserW, CreateProcessW},
+ winbase::{CreateProcessWithLogonW, CreateProcessWithTokenW},
+ winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR},
+ },
+};
+use winapi::{
shared::winerror::NO_ERROR,
um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW},
};
@@ -27,6 +57,12 @@ const NVCUDA_UTF8: &'static str = "NVCUDA.DLL";
const NVCUDA_UTF16: &[u16] = wch!("NVCUDA.DLL");
static mut ZLUDA_PATH_UTF8: Vec<u8> = Vec::new();
static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None;
+// Read by DllMain on DLL_PROCESS_DETACH to choose between
+// detach_load_library() (true) and detach_cuinit() (false).
+static mut DETACH_LOAD_LIBRARY: bool = false;
+// Module that exports the original cuInit; written by attach_cuinit() and
+// enumerated by cuinit_detour() to override every export.
+static mut NVCUDA_ORIGINAL_MODULE: HMODULE = ptr::null_mut();
+// Address of the original cuInit; written by attach_cuinit() and handed to
+// DetourDetach by detach_cuinit().
+static mut CUINIT_ORIGINAL_FN: FARPROC = ptr::null_mut();
+// File name of this DLL as returned by GetModuleFileNameA; filled in by
+// initialize_current_module_name() and passed to DetourUpdateProcessWithDll
+// when a child process is created.
+static mut CURRENT_MODULE_FILENAME: Vec<u8> = Vec::new();
+// Status returned by unsupported_cuda_fn() for exports ZLUDA does not provide.
+// NOTE(review): presumably mirrors CUDA_ERROR_NOT_SUPPORTED from the CUDA
+// driver API - confirm against cuda.h.
+const CUDA_ERROR_NOT_SUPPORTED: c_uint = 801;
+// Status returned by cuinit_detour() when loading ZLUDA or hooking fails.
+// NOTE(review): presumably mirrors CUDA_ERROR_UNKNOWN from the CUDA driver
+// API - confirm against cuda.h.
+const CUDA_ERROR_UNKNOWN: c_uint = 999;
static mut LOAD_LIBRARY_A: unsafe extern "system" fn(lpLibFileName: LPCSTR) -> HMODULE =
LoadLibraryA;
@@ -40,6 +76,72 @@ static mut LOAD_LIBRARY_EX_A: unsafe extern "system" fn(
dwFlags: DWORD,
) -> HMODULE = LoadLibraryExA;
+// Function pointer initialised to the real CreateProcessA. Its address is
+// passed to DetourAttach/DetourDetach in attach_create_process() and
+// detach_create_process(); ZludaCreateProcessA calls through it.
+static mut CREATE_PROCESS_A: unsafe extern "system" fn(
+ lpApplicationName: LPCSTR,
+ lpCommandLine: LPSTR,
+ lpProcessAttributes: LPSECURITY_ATTRIBUTES,
+ lpThreadAttributes: LPSECURITY_ATTRIBUTES,
+ bInheritHandles: BOOL,
+ dwCreationFlags: DWORD,
+ lpEnvironment: LPVOID,
+ lpCurrentDirectory: LPCSTR,
+ lpStartupInfo: LPSTARTUPINFOA,
+ lpProcessInformation: LPPROCESS_INFORMATION,
+) -> BOOL = CreateProcessA;
+
+// Function pointer initialised to the real CreateProcessW; hooked the same
+// way as CREATE_PROCESS_A and called from ZludaCreateProcessW.
+static mut CREATE_PROCESS_W: unsafe extern "system" fn(
+ lpApplicationName: LPCWSTR,
+ lpCommandLine: LPWSTR,
+ lpProcessAttributes: LPSECURITY_ATTRIBUTES,
+ lpThreadAttributes: LPSECURITY_ATTRIBUTES,
+ bInheritHandles: BOOL,
+ dwCreationFlags: DWORD,
+ lpEnvironment: LPVOID,
+ lpCurrentDirectory: LPCWSTR,
+ lpStartupInfo: LPSTARTUPINFOW,
+ lpProcessInformation: LPPROCESS_INFORMATION,
+) -> BOOL = CreateProcessW;
+
+// Function pointer initialised to the real CreateProcessAsUserW; hooked the
+// same way as CREATE_PROCESS_A and called from ZludaCreateProcessAsUserW.
+static mut CREATE_PROCESS_AS_USER_W: unsafe extern "system" fn(
+ hToken: HANDLE,
+ lpApplicationName: LPCWSTR,
+ lpCommandLine: LPWSTR,
+ lpProcessAttributes: LPSECURITY_ATTRIBUTES,
+ lpThreadAttributes: LPSECURITY_ATTRIBUTES,
+ bInheritHandles: BOOL,
+ dwCreationFlags: DWORD,
+ lpEnvironment: LPVOID,
+ lpCurrentDirectory: LPCWSTR,
+ lpStartupInfo: LPSTARTUPINFOW,
+ lpProcessInformation: LPPROCESS_INFORMATION,
+) -> BOOL = CreateProcessAsUserW;
+
+// Function pointer initialised to the real CreateProcessWithTokenW; hooked
+// the same way as CREATE_PROCESS_A and called from
+// ZludaCreateProcessWithTokenW.
+static mut CREATE_PROCESS_WITH_TOKEN_W: unsafe extern "system" fn(
+ hToken: HANDLE,
+ dwLogonFlags: DWORD,
+ lpApplicationName: LPCWSTR,
+ lpCommandLine: LPWSTR,
+ dwCreationFlags: DWORD,
+ lpEnvironment: LPVOID,
+ lpCurrentDirectory: LPCWSTR,
+ lpStartupInfo: LPSTARTUPINFOW,
+ lpProcessInformation: LPPROCESS_INFORMATION,
+) -> BOOL = CreateProcessWithTokenW;
+
+// Function pointer initialised to the real CreateProcessWithLogonW; hooked
+// the same way as CREATE_PROCESS_A and called from
+// ZludaCreateProcessWithLogonW.
+static mut CREATE_PROCESS_WITH_LOGON_W: unsafe extern "system" fn(
+ lpUsername: LPCWSTR,
+ lpDomain: LPCWSTR,
+ lpPassword: LPCWSTR,
+ dwLogonFlags: DWORD,
+ lpApplicationName: LPCWSTR,
+ lpCommandLine: LPWSTR,
+ dwCreationFlags: DWORD,
+ lpEnvironment: LPVOID,
+ lpCurrentDirectory: LPCWSTR,
+ lpStartupInfo: LPSTARTUPINFOW,
+ lpProcessInformation: LPPROCESS_INFORMATION,
+) -> BOOL = CreateProcessWithLogonW;
+
+
static mut LOAD_LIBRARY_EX_W: unsafe extern "system" fn(
lpLibFileName: LPCWSTR,
hFile: HANDLE,
@@ -100,6 +202,293 @@ unsafe extern "system" fn ZludaLoadLibraryExW(
(LOAD_LIBRARY_EX_W)(nvcuda_file_name, hFile, dwFlags)
}
+/// Replacement for `CreateProcessA` installed by `attach_create_process`.
+///
+/// Forwards every argument to the saved `CREATE_PROCESS_A` pointer with
+/// `CREATE_SUSPENDED` OR-ed into the creation flags, then passes the outcome
+/// and the caller's original flags to `continue_create_process_hook`.
+#[allow(non_snake_case)]
+unsafe extern "system" fn ZludaCreateProcessA(
+    lpApplicationName: LPCSTR,
+    lpCommandLine: LPSTR,
+    lpProcessAttributes: LPSECURITY_ATTRIBUTES,
+    lpThreadAttributes: LPSECURITY_ATTRIBUTES,
+    bInheritHandles: BOOL,
+    dwCreationFlags: DWORD,
+    lpEnvironment: LPVOID,
+    lpCurrentDirectory: LPCSTR,
+    lpStartupInfo: LPSTARTUPINFOA,
+    lpProcessInformation: LPPROCESS_INFORMATION,
+) -> BOOL {
+    // Always start the child suspended; the caller's flags are kept aside so
+    // the follow-up step can decide whether to resume it.
+    let suspended_flags = dwCreationFlags | CREATE_SUSPENDED;
+    let created = CREATE_PROCESS_A(
+        lpApplicationName,
+        lpCommandLine,
+        lpProcessAttributes,
+        lpThreadAttributes,
+        bInheritHandles,
+        suspended_flags,
+        lpEnvironment,
+        lpCurrentDirectory,
+        lpStartupInfo,
+        lpProcessInformation,
+    );
+    continue_create_process_hook(created, dwCreationFlags, lpProcessInformation)
+}
+
+
+/// Replacement for `CreateProcessW` installed by `attach_create_process`.
+///
+/// Forwards every argument to the saved `CREATE_PROCESS_W` pointer with
+/// `CREATE_SUSPENDED` OR-ed into the creation flags, then passes the outcome
+/// and the caller's original flags to `continue_create_process_hook`.
+#[allow(non_snake_case)]
+unsafe extern "system" fn ZludaCreateProcessW(
+    lpApplicationName: LPCWSTR,
+    lpCommandLine: LPWSTR,
+    lpProcessAttributes: LPSECURITY_ATTRIBUTES,
+    lpThreadAttributes: LPSECURITY_ATTRIBUTES,
+    bInheritHandles: BOOL,
+    dwCreationFlags: DWORD,
+    lpEnvironment: LPVOID,
+    lpCurrentDirectory: LPCWSTR,
+    lpStartupInfo: LPSTARTUPINFOW,
+    lpProcessInformation: LPPROCESS_INFORMATION,
+) -> BOOL {
+    // Always start the child suspended; the caller's flags are kept aside so
+    // the follow-up step can decide whether to resume it.
+    let suspended_flags = dwCreationFlags | CREATE_SUSPENDED;
+    let created = CREATE_PROCESS_W(
+        lpApplicationName,
+        lpCommandLine,
+        lpProcessAttributes,
+        lpThreadAttributes,
+        bInheritHandles,
+        suspended_flags,
+        lpEnvironment,
+        lpCurrentDirectory,
+        lpStartupInfo,
+        lpProcessInformation,
+    );
+    continue_create_process_hook(created, dwCreationFlags, lpProcessInformation)
+}
+
+
+/// Replacement for `CreateProcessAsUserW` installed by
+/// `attach_create_process`.
+///
+/// Forwards every argument to the saved `CREATE_PROCESS_AS_USER_W` pointer
+/// with `CREATE_SUSPENDED` OR-ed into the creation flags, then passes the
+/// outcome and the caller's original flags to `continue_create_process_hook`.
+#[allow(non_snake_case)]
+unsafe extern "system" fn ZludaCreateProcessAsUserW(
+    hToken: HANDLE,
+    lpApplicationName: LPCWSTR,
+    lpCommandLine: LPWSTR,
+    lpProcessAttributes: LPSECURITY_ATTRIBUTES,
+    lpThreadAttributes: LPSECURITY_ATTRIBUTES,
+    bInheritHandles: BOOL,
+    dwCreationFlags: DWORD,
+    lpEnvironment: LPVOID,
+    lpCurrentDirectory: LPCWSTR,
+    lpStartupInfo: LPSTARTUPINFOW,
+    lpProcessInformation: LPPROCESS_INFORMATION,
+) -> BOOL {
+    // Always start the child suspended; the caller's flags are kept aside so
+    // the follow-up step can decide whether to resume it.
+    let suspended_flags = dwCreationFlags | CREATE_SUSPENDED;
+    let created = CREATE_PROCESS_AS_USER_W(
+        hToken,
+        lpApplicationName,
+        lpCommandLine,
+        lpProcessAttributes,
+        lpThreadAttributes,
+        bInheritHandles,
+        suspended_flags,
+        lpEnvironment,
+        lpCurrentDirectory,
+        lpStartupInfo,
+        lpProcessInformation,
+    );
+    continue_create_process_hook(created, dwCreationFlags, lpProcessInformation)
+}
+
+
+/// Replacement for `CreateProcessWithLogonW` installed by
+/// `attach_create_process`.
+///
+/// Forwards every argument to the saved `CREATE_PROCESS_WITH_LOGON_W` pointer
+/// with `CREATE_SUSPENDED` OR-ed into the creation flags, then passes the
+/// outcome and the caller's original flags to `continue_create_process_hook`.
+#[allow(non_snake_case)]
+unsafe extern "system" fn ZludaCreateProcessWithLogonW(
+    lpUsername: LPCWSTR,
+    lpDomain: LPCWSTR,
+    lpPassword: LPCWSTR,
+    dwLogonFlags: DWORD,
+    lpApplicationName: LPCWSTR,
+    lpCommandLine: LPWSTR,
+    dwCreationFlags: DWORD,
+    lpEnvironment: LPVOID,
+    lpCurrentDirectory: LPCWSTR,
+    lpStartupInfo: LPSTARTUPINFOW,
+    lpProcessInformation: LPPROCESS_INFORMATION,
+) -> BOOL {
+    // Always start the child suspended; the caller's flags are kept aside so
+    // the follow-up step can decide whether to resume it.
+    let suspended_flags = dwCreationFlags | CREATE_SUSPENDED;
+    let created = CREATE_PROCESS_WITH_LOGON_W(
+        lpUsername,
+        lpDomain,
+        lpPassword,
+        dwLogonFlags,
+        lpApplicationName,
+        lpCommandLine,
+        suspended_flags,
+        lpEnvironment,
+        lpCurrentDirectory,
+        lpStartupInfo,
+        lpProcessInformation,
+    );
+    continue_create_process_hook(created, dwCreationFlags, lpProcessInformation)
+}
+
+
+/// Replacement for `CreateProcessWithTokenW` installed by
+/// `attach_create_process`.
+///
+/// Forwards every argument to the saved `CREATE_PROCESS_WITH_TOKEN_W` pointer
+/// with `CREATE_SUSPENDED` OR-ed into the creation flags, then passes the
+/// outcome and the caller's original flags to `continue_create_process_hook`.
+#[allow(non_snake_case)]
+unsafe extern "system" fn ZludaCreateProcessWithTokenW(
+    hToken: HANDLE,
+    dwLogonFlags: DWORD,
+    lpApplicationName: LPCWSTR,
+    lpCommandLine: LPWSTR,
+    dwCreationFlags: DWORD,
+    lpEnvironment: LPVOID,
+    lpCurrentDirectory: LPCWSTR,
+    lpStartupInfo: LPSTARTUPINFOW,
+    lpProcessInformation: LPPROCESS_INFORMATION,
+) -> BOOL {
+    let create_proc_result = CREATE_PROCESS_WITH_TOKEN_W(
+        hToken,
+        dwLogonFlags,
+        lpApplicationName,
+        lpCommandLine,
+        // The child must start suspended, exactly like in the other
+        // CreateProcess* hooks: continue_create_process_hook injects into it
+        // and then resumes the main thread unless the caller asked for
+        // CREATE_SUSPENDED. Without the flag the child would already be
+        // running during injection.
+        dwCreationFlags | CREATE_SUSPENDED,
+        lpEnvironment,
+        lpCurrentDirectory,
+        lpStartupInfo,
+        lpProcessInformation,
+    );
+    continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation)
+}
+
+
+/// Finishes a hooked `CreateProcess*` call by preparing the new child
+/// process for ZLUDA.
+///
+/// * `create_proc_result` - return value of the real `CreateProcess*` call;
+///   when it is 0 nothing else is done and 0 is returned.
+/// * `creation_flags` - the flags the application originally passed, i.e.
+///   without the `CREATE_SUSPENDED` bit added by the hook.
+/// * `process_information` - the structure filled in by the real call.
+///
+/// Steps, in order: register this DLL (`CURRENT_MODULE_FILENAME`) with the
+/// child through `DetourUpdateProcessWithDll`, copy the ZLUDA path into the
+/// child as a payload tagged with `PAYLOAD_GUID`, and resume the main thread
+/// unless the application itself requested `CREATE_SUSPENDED`.
+///
+/// If any step fails the child is terminated, both handles in
+/// `process_information` are closed and 0 is returned. The handles are closed
+/// here because a caller that sees a failed `CreateProcess*` has no reason to
+/// release them.
+unsafe fn continue_create_process_hook(
+    create_proc_result: BOOL,
+    creation_flags: DWORD,
+    process_information: LPPROCESS_INFORMATION,
+) -> BOOL {
+    if create_proc_result == 0 {
+        return 0;
+    }
+    if DetourUpdateProcessWithDll(
+        (*process_information).hProcess,
+        &mut CURRENT_MODULE_FILENAME.as_ptr() as *mut _ as *mut _,
+        1,
+    ) == 0
+    {
+        return abort_created_process(process_information);
+    }
+    if detours_sys::DetourCopyPayloadToProcess(
+        (*process_information).hProcess,
+        &PAYLOAD_GUID,
+        ZLUDA_PATH_UTF16.unwrap().as_ptr() as *mut _,
+        (ZLUDA_PATH_UTF16.unwrap().len() * mem::size_of::<u16>()) as u32,
+    ) == FALSE
+    {
+        return abort_created_process(process_information);
+    }
+
+    // The hook forced CREATE_SUSPENDED; undo that unless the application
+    // asked for a suspended process itself. ResumeThread reports failure as
+    // (DWORD)-1.
+    if creation_flags & CREATE_SUSPENDED == 0
+        && ResumeThread((*process_information).hThread) == -1i32 as u32
+    {
+        return abort_created_process(process_information);
+    }
+    create_proc_result
+}
+
+/// Kills a child process that could not be set up and releases the process
+/// and thread handles stored in `process_information`. Always returns 0 so
+/// that it can be used directly as the hook's failure result.
+unsafe fn abort_created_process(process_information: LPPROCESS_INFORMATION) -> BOOL {
+    TerminateProcess((*process_information).hProcess, 1);
+    CloseHandle((*process_information).hThread);
+    CloseHandle((*process_information).hProcess);
+    0
+}
+
+
+/// Replacement for the original `cuInit`, used when nvcuda.dll was already
+/// loaded before this DLL could redirect `LoadLibrary*`.
+///
+/// Loads ZLUDA from `ZLUDA_PATH_UTF16`, suspends every other thread of the
+/// process, and inside one Detours transaction redirects each export of
+/// `NVCUDA_ORIGINAL_MODULE` through `override_nvcuda_export`. Afterwards the
+/// threads are resumed and the call is forwarded to ZLUDA's own `cuInit`.
+///
+/// Returns `CUDA_ERROR_UNKNOWN` if ZLUDA cannot be loaded, does not export
+/// `cuInit`, or if suspending threads or any Detours step fails; otherwise
+/// returns whatever ZLUDA's `cuInit(flags)` returns.
+unsafe extern "C" fn cuinit_detour(flags: c_uint) -> c_uint {
+    let zluda_module = LoadLibraryW(ZLUDA_PATH_UTF16.unwrap().as_ptr());
+    if zluda_module == ptr::null_mut() {
+        return CUDA_ERROR_UNKNOWN;
+    }
+    // Resolve ZLUDA's cuInit before touching anything: a missing export must
+    // be reported as an error, not called through a null function pointer.
+    let zluda_cuinit = GetProcAddress(zluda_module, b"cuInit\0".as_ptr() as *const _);
+    if zluda_cuinit.is_null() {
+        return CUDA_ERROR_UNKNOWN;
+    }
+    let suspended_threads = suspend_all_threads_except_current();
+    let suspended_threads = match suspended_threads {
+        Some(t) => t,
+        None => return CUDA_ERROR_UNKNOWN,
+    };
+    if DetourTransactionBegin() != NO_ERROR as i32 {
+        resume_threads(&suspended_threads);
+        return CUDA_ERROR_UNKNOWN;
+    }
+    // Enlist every suspended thread in the transaction.
+    for t in suspended_threads.iter() {
+        if DetourUpdateThread(*t) != NO_ERROR as i32 {
+            DetourTransactionAbort();
+            resume_threads(&suspended_threads);
+            return CUDA_ERROR_UNKNOWN;
+        }
+    }
+    // The callback receives a pointer to `zluda_module` as its context.
+    if detours_sys::DetourEnumerateExports(
+        NVCUDA_ORIGINAL_MODULE as *mut _,
+        &zluda_module as *const _ as *mut _,
+        Some(override_nvcuda_export),
+    ) == FALSE
+    {
+        DetourTransactionAbort();
+        resume_threads(&suspended_threads);
+        return CUDA_ERROR_UNKNOWN;
+    }
+    if DetourTransactionCommit() != NO_ERROR as i32 {
+        DetourTransactionAbort();
+        resume_threads(&suspended_threads);
+        return CUDA_ERROR_UNKNOWN;
+    }
+    resume_threads(&suspended_threads);
+    (mem::transmute::<_, unsafe extern "C" fn(c_uint) -> c_uint>(zluda_cuinit))(flags)
+}
+
+
+/// Suspends every thread of the current process except the calling one.
+///
+/// Takes a Toolhelp thread snapshot, and for each entry that belongs to this
+/// process and is not the current thread opens it with
+/// `THREAD_SUSPEND_RESUME` access and suspends it.
+///
+/// Returns the handles of all suspended threads; the caller is expected to
+/// pass them to `resume_threads`, which also closes them. Returns `None` if
+/// the snapshot cannot be taken or walked, or if any thread cannot be opened
+/// or suspended; in that case every thread suspended so far is resumed and
+/// every handle opened here is closed.
+unsafe fn suspend_all_threads_except_current() -> Option<Vec<*mut c_void>> {
+    let thread_snap = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0);
+    if thread_snap == INVALID_HANDLE_VALUE {
+        return None;
+    }
+    let current_thread = GetCurrentThreadId();
+    let current_process = GetCurrentProcessId();
+    let mut threads = Vec::new();
+    let mut thread = mem::zeroed::<THREADENTRY32>();
+    thread.dwSize = mem::size_of::<THREADENTRY32>() as u32;
+    if Thread32First(thread_snap, &mut thread) == 0 {
+        CloseHandle(thread_snap);
+        return None;
+    }
+    loop {
+        if thread.th32OwnerProcessID == current_process && thread.th32ThreadID != current_thread {
+            let thread_handle = OpenThread(THREAD_SUSPEND_RESUME, 0, thread.th32ThreadID);
+            if thread_handle == ptr::null_mut() {
+                CloseHandle(thread_snap);
+                resume_threads(&threads);
+                return None;
+            }
+            if SuspendThread(thread_handle) == (-1i32 as u32) {
+                // This handle is not in `threads` yet, so resume_threads()
+                // will not close it; release it here to avoid leaking it.
+                CloseHandle(thread_handle);
+                CloseHandle(thread_snap);
+                resume_threads(&threads);
+                return None;
+            }
+            threads.push(thread_handle);
+        }
+        if Thread32Next(thread_snap, &mut thread) == 0 {
+            break;
+        }
+    }
+    CloseHandle(thread_snap);
+    Some(threads)
+}
+
+
+/// Resumes each thread in `threads` and then closes its handle, in slice
+/// order. The results of `ResumeThread` and `CloseHandle` are not checked.
+unsafe fn resume_threads(threads: &[*mut c_void]) {
+    threads.iter().copied().for_each(|handle| {
+        ResumeThread(handle);
+        CloseHandle(handle);
+    });
+}
+
+
+/// `DetourEnumerateExports` callback that redirects one export of the
+/// original nvcuda module to ZLUDA.
+///
+/// * `context_ptr` - points to the `HMODULE` of the loaded ZLUDA library
+///   (set up by `cuinit_detour`).
+/// * `name` - name of the export being visited.
+/// * `address` - address of that export in the original module.
+///
+/// The export is detoured to the ZLUDA function of the same name, or to
+/// `unsupported_cuda_fn` when ZLUDA has no such export. Returns `TRUE` to
+/// continue the enumeration and `FALSE` when `DetourAttach` fails.
+unsafe extern "C" fn override_nvcuda_export(
+    context_ptr: *mut c_void,
+    _: c_ulong,
+    name: LPCSTR,
+    address: *mut c_void,
+) -> c_int {
+    let zluda_module: HMODULE = *(context_ptr as *mut HMODULE);
+    let mut zluda_fn = GetProcAddress(zluda_module, name);
+    if zluda_fn == ptr::null_mut() {
+        // We only support 64 bits and in all relevant calling conventions stack
+        // is caller-cleaned, so probably we will not crash
+        zluda_fn = unsupported_cuda_fn as *mut _;
+    }
+    // DetourAttach keeps the pointer-to-pointer it is given and writes the
+    // trampoline address through it when the transaction is committed, which
+    // happens after this callback has returned. The slot therefore must not
+    // live on this function's stack; it is heap-allocated and intentionally
+    // leaked (one pointer per export, for the lifetime of the process).
+    let target_slot: *mut *mut c_void = Box::into_raw(Box::new(address));
+    if DetourAttach(target_slot, zluda_fn as *mut _) != NO_ERROR as i32 {
+        return FALSE;
+    }
+    TRUE
+}
+
+
+/// Stand-in used by `override_nvcuda_export` for exports that ZLUDA does not
+/// provide; ignores any arguments the caller passed and reports
+/// `CUDA_ERROR_NOT_SUPPORTED`.
+unsafe extern "C" fn unsupported_cuda_fn() -> c_uint {
+    let status: c_uint = CUDA_ERROR_NOT_SUPPORTED;
+    status
+}
+
fn is_nvcuda_dll_utf8(lib: *const u8) -> bool {
is_nvcuda_dll(lib, 0, NVCUDA_UTF8.as_bytes(), |c| {
if c >= 'a' as u8 && c <= 'z' as u8 {
@@ -147,94 +536,207 @@ fn is_nvcuda_dll<T: Copy + PartialEq>(
+/// DLL entry point.
+///
+/// On `DLL_PROCESS_ATTACH`: calls `DetourRestoreAfterWith`, records this
+/// module's file name, resolves the ZLUDA DLL path, and then installs hooks:
+/// on `cuInit` when an already-loaded module exports it, otherwise on the
+/// `LoadLibrary*` family. Returns `FALSE` if any of these steps fails.
+///
+/// On `DLL_PROCESS_DETACH`: removes the hooks selected by
+/// `DETACH_LOAD_LIBRARY`. Any other reason returns `TRUE`.
#[allow(non_snake_case)]
#[no_mangle]
-unsafe extern "system" fn DllMain(_: *const u8, dwReason: u32, _: *const u8) -> i32 {
+unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u8) -> i32 {
if dwReason == DLL_PROCESS_ATTACH {
if DetourRestoreAfterWith() == FALSE {
return FALSE;
}
+ if !initialize_current_module_name(instDLL) {
+ return FALSE;
+ }
match get_zluda_dll_path() {
Some(path) => {
ZLUDA_PATH_UTF16 = Some(path);
+ // from_utf16_lossy(...) handles terminating NULL correctly
ZLUDA_PATH_UTF8 = String::from_utf16_lossy(path).into_bytes();
}
None => return FALSE,
}
- if DetourTransactionBegin() != NO_ERROR as i32 {
- return FALSE;
- }
- if DetourUpdateThread(GetCurrentThread()) != NO_ERROR as i32 {
- return FALSE;
- }
- if DetourAttach(
- mem::transmute(&mut LOAD_LIBRARY_A),
- ZludaLoadLibraryA as *mut _,
- ) != NO_ERROR as i32
- {
- return FALSE;
- }
- if DetourAttach(
- mem::transmute(&mut LOAD_LIBRARY_W),
- ZludaLoadLibraryW as *mut _,
- ) != NO_ERROR as i32
- {
- return FALSE;
- }
- if DetourAttach(
- mem::transmute(&mut LOAD_LIBRARY_EX_A),
- ZludaLoadLibraryExA as *mut _,
- ) != NO_ERROR as i32
- {
- return FALSE;
- }
- if DetourAttach(
- mem::transmute(&mut LOAD_LIBRARY_EX_W),
- ZludaLoadLibraryExW as *mut _,
- ) != NO_ERROR as i32
- {
- return FALSE;
- }
- if DetourTransactionCommit() != NO_ERROR as i32 {
- return FALSE;
+ // If the application (directly or not) links to nvcuda.dll, nvcuda.dll
+ // will get loaded before we can act. In this case, instead of
+ // redirecting LoadLibrary* to load ZLUDA, we redirect cuInit to
+ // a cuInit implementation that will load ZLUDA and set up detours.
+ // We can't do it here because LoadLibrary* inside DllMain is illegal.
+ // We greatly prefer wholesale redirecting inside LoadLibrary*.
+ // Hooking inside cuInit is brittle in the face of multiple
+ // threads (DetourUpdateThread)
+ match get_cuinit() {
+ Some((nvcuda_mod, cuinit_fn)) => attach_cuinit(nvcuda_mod, cuinit_fn),
+ None => attach_load_libary(),
}
} else if dwReason == DLL_PROCESS_DETACH {
- if DetourTransactionBegin() != NO_ERROR as i32 {
- return FALSE;
- }
- if DetourUpdateThread(GetCurrentThread()) != NO_ERROR as i32 {
- return FALSE;
- }
- if DetourDetach(
- mem::transmute(&mut LOAD_LIBRARY_A),
- ZludaLoadLibraryA as *mut _,
- ) != NO_ERROR as i32
- {
- return FALSE;
+ if DETACH_LOAD_LIBRARY {
+ detach_load_library()
+ } else {
+ detach_cuinit()
}
- if DetourDetach(
- mem::transmute(&mut LOAD_LIBRARY_W),
- ZludaLoadLibraryW as *mut _,
- ) != NO_ERROR as i32
- {
- return FALSE;
+ } else {
+ TRUE
+ }
+}
+
+/// Stores the file name of the module `current_module` (this DLL) in
+/// `CURRENT_MODULE_FILENAME` as a NUL-terminated byte string.
+///
+/// Calls `GetModuleFileNameA` with a buffer that starts at 128 bytes and is
+/// doubled until the name fits. Returns `false` if `GetModuleFileNameA`
+/// reports failure (returns 0), `true` once the name has been stored.
+#[must_use]
+unsafe fn initialize_current_module_name(current_module: HINSTANCE) -> bool {
+ let mut name = vec![0; 128 as usize];
+ loop {
+ let size = GetModuleFileNameA(
+ current_module,
+ name.as_mut_ptr() as *mut _,
+ name.len() as u32,
+ );
+ if size == 0 {
+ return false;
}
- if DetourDetach(
- mem::transmute(&mut LOAD_LIBRARY_EX_A),
- ZludaLoadLibraryExA as *mut _,
- ) != NO_ERROR as i32
- {
- return FALSE;
+ if size < name.len() as u32 {
+ // `size` excludes the terminating NUL. Keep that byte inside the
+ // vector: the buffer is later handed to Detours as a C string via
+ // as_ptr(), so the terminator must be part of the stored data
+ // rather than left over in spare capacity.
+ name.truncate(size as usize + 1);
+ CURRENT_MODULE_FILENAME = name;
+ return true;
}
- if DetourDetach(
- mem::transmute(&mut LOAD_LIBRARY_EX_W),
- ZludaLoadLibraryExW as *mut _,
- ) != NO_ERROR as i32
- {
- return FALSE;
+ // The name did not fit; retry with a buffer twice as large.
+ name.resize(name.len() * 2, 0);
+ }
+}
+
+/// Looks for an already-loaded module that exports `cuInit`.
+///
+/// Walks the modules returned by `DetourEnumerateModules`, starting from a
+/// null handle, and returns the first module for which `GetProcAddress`
+/// finds `cuInit`, together with that export's address. Returns `None` when
+/// the enumeration ends (null module) without a match.
+unsafe fn get_cuinit() -> Option<(HMODULE, FARPROC)> {
+ let mut module = ptr::null_mut();
+ loop {
+ module = detours_sys::DetourEnumerateModules(module);
+ if module == ptr::null_mut() {
+ return None;
}
- if DetourTransactionCommit() != NO_ERROR as i32 {
- return FALSE;
+ let cuinit_addr = GetProcAddress(module as *mut _, b"cuInit\0".as_ptr() as *const _);
+ if cuinit_addr != ptr::null_mut() {
+ return Some((module as *mut _, cuinit_addr));
}
}
+}
+
+/// Installs the `CreateProcess*` hooks and detours the original `cuInit`
+/// (`cuinit`, exported by `nvcuda_mod`) to `cuinit_detour`, all in one
+/// Detours transaction.
+///
+/// Records `nvcuda_mod` in `NVCUDA_ORIGINAL_MODULE` for later use by
+/// `cuinit_detour`. Returns `TRUE` on success and `FALSE` if any Detours
+/// call fails.
+#[must_use]
+unsafe fn attach_cuinit(nvcuda_mod: HMODULE, cuinit: FARPROC) -> i32 {
+    if DetourTransactionBegin() != NO_ERROR as i32 {
+        return FALSE;
+    }
+    if !attach_create_process() {
+        return FALSE;
+    }
+    NVCUDA_ORIGINAL_MODULE = nvcuda_mod;
+    CUINIT_ORIGINAL_FN = cuinit;
+    // Attach through the static itself, not through a local copy: Detours
+    // updates the pointed-to variable on commit, and detach_cuinit() later
+    // passes the very same static to DetourDetach. Attaching via a local
+    // would leave the static holding a value DetourDetach cannot match.
+    if DetourAttach(
+        mem::transmute(&mut CUINIT_ORIGINAL_FN),
+        cuinit_detour as *mut _,
+    ) != NO_ERROR as i32
+    {
+        return FALSE;
+    }
+    if DetourTransactionCommit() != NO_ERROR as i32 {
+        return FALSE;
+    }
+    TRUE
+}
+
+/// Removes the `CreateProcess*` hooks and the `cuInit` detour in a single
+/// Detours transaction.
+///
+/// The steps run in order and stop at the first failure; returns `TRUE` only
+/// when every step, including the commit, succeeds.
+#[must_use]
+unsafe fn detach_cuinit() -> i32 {
+    let ok = NO_ERROR as i32;
+    // `&&` short-circuits, so a failing step prevents all later ones.
+    let detached = DetourTransactionBegin() == ok
+        && detach_create_process()
+        && DetourUpdateThread(GetCurrentThread()) == ok
+        && DetourDetach(
+            mem::transmute(&mut CUINIT_ORIGINAL_FN),
+            cuinit_detour as *mut _,
+        ) == ok
+        && DetourTransactionCommit() == ok;
+    if detached {
+        TRUE
+    } else {
+        FALSE
+    }
+}
+
+/// Installs the `CreateProcess*` hooks and detours `LoadLibraryA`,
+/// `LoadLibraryW`, `LoadLibraryExA` and `LoadLibraryExW` to their `Zluda*`
+/// replacements, all in one Detours transaction.
+///
+/// On success sets `DETACH_LOAD_LIBRARY` so that `DllMain` removes these
+/// hooks (and not the `cuInit` one) on `DLL_PROCESS_DETACH`, and returns
+/// `TRUE`. Returns `FALSE` if any Detours call fails.
+#[must_use]
+unsafe fn attach_load_libary() -> i32 {
+    if DetourTransactionBegin() != NO_ERROR as i32 {
+        return FALSE;
+    }
+    if !attach_create_process() {
+        return FALSE;
+    }
+    if DetourAttach(
+        mem::transmute(&mut LOAD_LIBRARY_A),
+        ZludaLoadLibraryA as *mut _,
+    ) != NO_ERROR as i32
+    {
+        return FALSE;
+    }
+    if DetourAttach(
+        mem::transmute(&mut LOAD_LIBRARY_W),
+        ZludaLoadLibraryW as *mut _,
+    ) != NO_ERROR as i32
+    {
+        return FALSE;
+    }
+    if DetourAttach(
+        mem::transmute(&mut LOAD_LIBRARY_EX_A),
+        ZludaLoadLibraryExA as *mut _,
+    ) != NO_ERROR as i32
+    {
+        return FALSE;
+    }
+    if DetourAttach(
+        mem::transmute(&mut LOAD_LIBRARY_EX_W),
+        ZludaLoadLibraryExW as *mut _,
+    ) != NO_ERROR as i32
+    {
+        return FALSE;
+    }
+    if DetourTransactionCommit() != NO_ERROR as i32 {
+        return FALSE;
+    }
+    // Remember which hook set is active; without this DllMain would always
+    // take the detach_cuinit() path on process detach.
+    DETACH_LOAD_LIBRARY = true;
+    TRUE
+}
+
+/// Removes the `CreateProcess*` hooks and the four `LoadLibrary*` detours in
+/// a single Detours transaction.
+///
+/// The steps run in order and stop at the first failure; returns `TRUE` only
+/// when every step, including the commit, succeeds.
+#[must_use]
+unsafe fn detach_load_library() -> i32 {
+    let ok = NO_ERROR as i32;
+    // `&&` short-circuits, so a failing step prevents all later ones.
+    let detached = DetourTransactionBegin() == ok
+        && detach_create_process()
+        && DetourUpdateThread(GetCurrentThread()) == ok
+        && DetourDetach(
+            mem::transmute(&mut LOAD_LIBRARY_A),
+            ZludaLoadLibraryA as *mut _,
+        ) == ok
+        && DetourDetach(
+            mem::transmute(&mut LOAD_LIBRARY_W),
+            ZludaLoadLibraryW as *mut _,
+        ) == ok
+        && DetourDetach(
+            mem::transmute(&mut LOAD_LIBRARY_EX_A),
+            ZludaLoadLibraryExA as *mut _,
+        ) == ok
+        && DetourDetach(
+            mem::transmute(&mut LOAD_LIBRARY_EX_W),
+            ZludaLoadLibraryExW as *mut _,
+        ) == ok
+        && DetourTransactionCommit() == ok;
+    if !detached {
+        return FALSE;
+    }
     TRUE
 }
@@ -258,3 +760,83 @@ fn get_zluda_dll_path() -> Option<&'static [u16]> {
}
None
}
+
+/// Queues detours for the five `CreateProcess*` entry points
+/// (`A`, `W`, `AsUserW`, `WithLogonW`, `WithTokenW`, in that order) onto
+/// their `ZludaCreateProcess*` replacements.
+///
+/// Only calls `DetourAttach`; the caller is responsible for beginning and
+/// committing the transaction. Stops at the first failing attach and returns
+/// `false`; returns `true` when all five succeed.
+#[must_use]
+unsafe fn attach_create_process() -> bool {
+    let ok = NO_ERROR as i32;
+    // `&&` short-circuits, so nothing after a failed attach is attempted.
+    DetourAttach(
+        mem::transmute(&mut CREATE_PROCESS_A),
+        ZludaCreateProcessA as *mut _,
+    ) == ok
+        && DetourAttach(
+            mem::transmute(&mut CREATE_PROCESS_W),
+            ZludaCreateProcessW as *mut _,
+        ) == ok
+        && DetourAttach(
+            mem::transmute(&mut CREATE_PROCESS_AS_USER_W),
+            ZludaCreateProcessAsUserW as *mut _,
+        ) == ok
+        && DetourAttach(
+            mem::transmute(&mut CREATE_PROCESS_WITH_LOGON_W),
+            ZludaCreateProcessWithLogonW as *mut _,
+        ) == ok
+        && DetourAttach(
+            mem::transmute(&mut CREATE_PROCESS_WITH_TOKEN_W),
+            ZludaCreateProcessWithTokenW as *mut _,
+        ) == ok
+}
+
+
+/// Queues removal of the detours on the five `CreateProcess*` entry points
+/// (`A`, `W`, `AsUserW`, `WithLogonW`, `WithTokenW`, in that order).
+///
+/// Only calls `DetourDetach`; the caller is responsible for beginning and
+/// committing the transaction. Stops at the first failing detach and returns
+/// `false`; returns `true` when all five succeed.
+#[must_use]
+unsafe fn detach_create_process() -> bool {
+    let ok = NO_ERROR as i32;
+    // `&&` short-circuits, so nothing after a failed detach is attempted.
+    DetourDetach(
+        mem::transmute(&mut CREATE_PROCESS_A),
+        ZludaCreateProcessA as *mut _,
+    ) == ok
+        && DetourDetach(
+            mem::transmute(&mut CREATE_PROCESS_W),
+            ZludaCreateProcessW as *mut _,
+        ) == ok
+        && DetourDetach(
+            mem::transmute(&mut CREATE_PROCESS_AS_USER_W),
+            ZludaCreateProcessAsUserW as *mut _,
+        ) == ok
+        && DetourDetach(
+            mem::transmute(&mut CREATE_PROCESS_WITH_LOGON_W),
+            ZludaCreateProcessWithLogonW as *mut _,
+        ) == ok
+        && DetourDetach(
+            mem::transmute(&mut CREATE_PROCESS_WITH_TOKEN_W),
+            ZludaCreateProcessWithTokenW as *mut _,
+        ) == ok
+}