aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--zluda_dump/src/os_win.rs1
-rw-r--r--zluda_inject/Cargo.toml4
-rw-r--r--zluda_redirect/Cargo.toml3
-rw-r--r--zluda_redirect/src/lib.rs133
4 files changed, 112 insertions, 29 deletions
diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs
index 2bfc457..d80f6e6 100644
--- a/zluda_dump/src/os_win.rs
+++ b/zluda_dump/src/os_win.rs
@@ -66,7 +66,6 @@ impl PlatformLibrary {
if module == ptr::null_mut() {
break;
}
- let mut size = 0;
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
if payload != ptr::null_mut() {
return Some(module as _);
diff --git a/zluda_inject/Cargo.toml b/zluda_inject/Cargo.toml
index cc9dc9b..73489bb 100644
--- a/zluda_inject/Cargo.toml
+++ b/zluda_inject/Cargo.toml
@@ -13,7 +13,7 @@ winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchap
detours-sys = { path = "../detours-sys" }
[dev-dependencies]
-# dependency for integration tests
+# all of these are used in integration tests
zluda_redirect = { path = "../zluda_redirect" }
-# dependency for integration tests
zluda_dump = { path = "../zluda_dump" }
+zluda_ml = { path = "../zluda_ml" }
diff --git a/zluda_redirect/Cargo.toml b/zluda_redirect/Cargo.toml
index 85c2976..97fb3e2 100644
--- a/zluda_redirect/Cargo.toml
+++ b/zluda_redirect/Cargo.toml
@@ -10,4 +10,5 @@ crate-type = ["cdylib"]
[target.'cfg(windows)'.dependencies]
detours-sys = { path = "../detours-sys" }
wchar = "0.6"
-winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] } \ No newline at end of file
+winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] }
+tempfile = "3" \ No newline at end of file
diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs
index f2d2739..657500c 100644
--- a/zluda_redirect/src/lib.rs
+++ b/zluda_redirect/src/lib.rs
@@ -6,34 +6,19 @@ extern crate winapi;
use std::{
collections::HashMap,
ffi::{c_void, CStr},
- mem,
+ io, mem,
os::raw::c_uint,
ptr, slice, usize,
};
use detours_sys::{
- DetourAttach, DetourEnumerateExports, DetourRestoreAfterWith, DetourTransactionAbort,
- DetourTransactionBegin, DetourTransactionCommit, DetourUpdateProcessWithDll,
- DetourUpdateThread,
+ DetourAttach, DetourEnumerateExports, DetourGetEntryPoint, DetourRestoreAfterWith,
+ DetourTransactionAbort, DetourTransactionBegin, DetourTransactionCommit,
+ DetourUpdateProcessWithDll, DetourUpdateThread,
};
+use tempfile::TempDir;
use wchar::wch;
use winapi::{
- shared::minwindef::{BOOL, LPVOID},
- um::{
- handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
- minwinbase::LPSECURITY_ATTRIBUTES,
- processthreadsapi::{
- CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
- SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW,
- },
- tlhelp32::{
- CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32,
- },
- winbase::CREATE_SUSPENDED,
- winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME},
- },
-};
-use winapi::{
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
};
@@ -50,6 +35,26 @@ use winapi::{
shared::winerror::NO_ERROR,
um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW},
};
+use winapi::{
+ shared::{
+ minwindef::{BOOL, LPVOID},
+ winerror::E_UNEXPECTED,
+ },
+ um::{
+ handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
+ libloaderapi::GetModuleHandleW,
+ minwinbase::LPSECURITY_ATTRIBUTES,
+ processthreadsapi::{
+ CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
+ SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW,
+ },
+ tlhelp32::{
+ CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32,
+ },
+ winbase::{CopyFileW, CreateSymbolicLinkW, CREATE_SUSPENDED},
+ winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME},
+ },
+};
include!("payload_guid.rs");
@@ -375,6 +380,59 @@ unsafe extern "system" fn ZludaCreateProcessWithTokenW(
continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation)
}
+static mut MAIN: unsafe extern "system" fn() -> DWORD = ZludaMain;
+
+// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#search-order-for-desktop-applications
+// "If a DLL with the same module name is already loaded in memory, the system
+// uses the loaded DLL, no matter which directory it is in. The system does not
+// search for the DLL."
+#[allow(non_snake_case)]
+unsafe extern "system" fn ZludaMain() -> DWORD {
+ let temp_dir = match do_zluda_preload() {
+ Ok(f) => f,
+ Err(e) => return e.raw_os_error().unwrap_or(E_UNEXPECTED) as u32,
+ };
+ let result = MAIN();
+ drop(temp_dir);
+ result
+}
+
+unsafe fn do_zluda_preload() -> std::io::Result<TempDir> {
+ let temp_dir = tempfile::tempdir()?;
+ do_single_zluda_preload(&temp_dir, ZLUDA_PATH_UTF16.unwrap().as_ptr(), NVCUDA_UTF8)?;
+ do_single_zluda_preload(&temp_dir, ZLUDA_ML_PATH_UTF16.unwrap().as_ptr(), NVML_UTF8)?;
+ Ok(temp_dir)
+}
+
+unsafe fn do_single_zluda_preload(
+ temp_dir: &TempDir,
+ full_path: *const u16,
+ file_name: &'static str,
+) -> io::Result<()> {
+ let mut temp_file_path = temp_dir.path().to_path_buf();
+ temp_file_path.push(file_name);
+ let mut temp_file_path_utf16 = temp_file_path
+ .into_os_string()
+ .to_string_lossy()
+ .encode_utf16()
+ .collect::<Vec<_>>();
+ temp_file_path_utf16.push(0);
+ // Try a symlink first; if that fails we are probably not in developer mode, so do a copy then
+ if 0 == CreateSymbolicLinkW(
+ temp_file_path_utf16.as_ptr(),
+ full_path,
+ 0x2, //SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE
+ ) {
+ if 0 == CopyFileW(full_path, temp_file_path_utf16.as_ptr(), 1) {
+ return Err(io::Error::last_os_error());
+ }
+ }
+ if ptr::null_mut() == ZludaLoadLibraryW_NoRedirect(temp_file_path_utf16.as_ptr()) {
+ return Err(io::Error::last_os_error());
+ }
+ Ok(())
+}
+
// This type encapsulates typical calling sequence of detours and cleanup.
// We have two ways we do detours:
// * If we are loaded before nvcuda.dll, we hook LoadLibrary*
@@ -668,8 +726,8 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
// redirecting LoadLibrary* to load ZLUDA, we override already loaded
// functions
let detach_guard = match get_cuinit() {
- Some((nvcuda_mod, _)) => attach_cuinit(nvcuda_mod),
- None => attach_load_libary(),
+ Some((nvcuda_mod, _)) => detour_already_loaded_nvcuda(nvcuda_mod),
+ None => detour_main(),
};
match detach_guard {
Some(g) => {
@@ -724,7 +782,7 @@ unsafe fn get_cuinit() -> Option<(HMODULE, FARPROC)> {
}
#[must_use]
-unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
+unsafe fn detour_already_loaded_nvcuda(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
let zluda_module = LoadLibraryW(ZLUDA_PATH_UTF16.unwrap().as_ptr());
if zluda_module == ptr::null_mut() {
return None;
@@ -747,7 +805,22 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
(original_fn_address as _, override_fn_address),
);
}
- DetourDetachGuard::detour_functions(nvcuda_mod, Vec::new(), override_fn_pairs)
+ let detour_functions = vec![
+ (
+ &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
+ ZludaLoadLibraryA as *mut c_void,
+ ),
+ (&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
+ (
+ &mut LOAD_LIBRARY_EX_A as *mut _ as _,
+ ZludaLoadLibraryExA as _,
+ ),
+ (
+ &mut LOAD_LIBRARY_EX_W as *mut _ as _,
+ ZludaLoadLibraryExW as _,
+ ),
+ ];
+ DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, override_fn_pairs)
}
unsafe extern "system" fn cuda_unsupported() -> c_uint {
@@ -776,9 +849,19 @@ unsafe extern "stdcall" fn gather_imports_impl(
}
#[must_use]
-unsafe fn attach_load_libary() -> Option<DetourDetachGuard> {
+unsafe fn detour_main() -> Option<DetourDetachGuard> {
+ let exe_handle = GetModuleHandleW(ptr::null());
+ let entry_point = DetourGetEntryPoint(exe_handle as _);
+ if entry_point == ptr::null_mut() {
+ return None;
+ }
+ MAIN = mem::transmute(entry_point);
let detour_functions = vec![
(
+ &mut MAIN as *mut _ as *mut *mut c_void,
+ ZludaMain as *mut c_void,
+ ),
+ (
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
ZludaLoadLibraryA as *mut c_void,
),