aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_redirect/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda_redirect/src/lib.rs')
-rw-r--r--zluda_redirect/src/lib.rs155
1 files changed, 127 insertions, 28 deletions
diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs
index 657500c..0f73ced 100644
--- a/zluda_redirect/src/lib.rs
+++ b/zluda_redirect/src/lib.rs
@@ -12,15 +12,27 @@ use std::{
};
use detours_sys::{
- DetourAttach, DetourEnumerateExports, DetourGetEntryPoint, DetourRestoreAfterWith,
- DetourTransactionAbort, DetourTransactionBegin, DetourTransactionCommit,
- DetourUpdateProcessWithDll, DetourUpdateThread,
+ DetourAllocateRegionWithinJumpBounds, DetourAttach, DetourEnumerateExports,
+ DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin,
+ DetourTransactionCommit, DetourUpdateProcessWithDll, DetourUpdateThread,
};
+use goblin::pe::{
+ self,
+ header::{CoffHeader, DOS_MAGIC, PE_MAGIC, PE_POINTER_OFFSET},
+ optional_header::StandardFields64,
+};
+use memoffset::offset_of;
use tempfile::TempDir;
use wchar::wch;
use winapi::{
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
- um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
+ um::{
+ libloaderapi::{GetModuleHandleA, LoadLibraryExA},
+ memoryapi::VirtualProtect,
+ processthreadsapi::{FlushInstructionCache, GetCurrentProcess},
+ sysinfoapi::GetSystemInfo,
+ winnt::{LPCSTR, PAGE_READWRITE},
+ },
};
use winapi::{
shared::minwindef::{FARPROC, HINSTANCE},
@@ -380,19 +392,27 @@ unsafe extern "system" fn ZludaCreateProcessWithTokenW(
continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation)
}
-static mut MAIN: unsafe extern "system" fn() -> DWORD = ZludaMain;
+static mut MAIN: unsafe extern "system" fn() -> DWORD = zluda_main;
+static mut COR_EXE_MAIN: unsafe extern "system" fn() -> DWORD = zluda_main_clr;
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#search-order-for-desktop-applications
// "If a DLL with the same module name is already loaded in memory, the system
// uses the loaded DLL, no matter which directory it is in. The system does not
// search for the DLL."
-#[allow(non_snake_case)]
-unsafe extern "system" fn ZludaMain() -> DWORD {
+unsafe extern "system" fn zluda_main() -> DWORD {
+ zluda_main_impl(MAIN)
+}
+
+unsafe extern "system" fn zluda_main_clr() -> DWORD {
+ zluda_main_impl(COR_EXE_MAIN)
+}
+
+unsafe fn zluda_main_impl(original: unsafe extern "system" fn() -> DWORD) -> DWORD {
let temp_dir = match do_zluda_preload() {
Ok(f) => f,
- Err(e) => return e.raw_os_error().unwrap_or(E_UNEXPECTED) as u32,
+ Err(e) => return e.raw_os_error().unwrap_or(E_UNEXPECTED) as DWORD,
};
- let result = MAIN();
+ let result = original();
drop(temp_dir);
result
}
@@ -768,17 +788,15 @@ unsafe fn initialize_current_module_name(current_module: HINSTANCE) -> bool {
}
unsafe fn get_cuinit() -> Option<(HMODULE, FARPROC)> {
- let mut module = ptr::null_mut();
- loop {
- module = detours_sys::DetourEnumerateModules(module);
- if module == ptr::null_mut() {
- return None;
- }
- let cuinit_addr = GetProcAddress(module as *mut _, b"cuInit\0".as_ptr() as *const _);
- if cuinit_addr != ptr::null_mut() {
- return Some((module as *mut _, cuinit_addr));
- }
+ let nvcuda = GetModuleHandleA(b"nvcuda\0".as_ptr() as _);
+ if nvcuda == ptr::null_mut() {
+ return None;
+ }
+ let cuinit_addr = GetProcAddress(nvcuda, b"cuInit\0".as_ptr() as _);
+ if cuinit_addr == ptr::null_mut() {
+ return None;
}
+ Some((nvcuda as *mut _, cuinit_addr))
}
#[must_use]
@@ -850,17 +868,10 @@ unsafe extern "stdcall" fn gather_imports_impl(
#[must_use]
unsafe fn detour_main() -> Option<DetourDetachGuard> {
- let exe_handle = GetModuleHandleW(ptr::null());
- let entry_point = DetourGetEntryPoint(exe_handle as _);
- if entry_point == ptr::null_mut() {
+ if !override_entry_point() {
return None;
}
- MAIN = mem::transmute(entry_point);
- let detour_functions = vec![
- (
- &mut MAIN as *mut _ as *mut *mut c_void,
- ZludaMain as *mut c_void,
- ),
+ let mut detour_functions = vec![
(
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
ZludaLoadLibraryA as *mut c_void,
@@ -875,9 +886,97 @@ unsafe fn detour_main() -> Option<DetourDetachGuard> {
ZludaLoadLibraryExW as _,
),
];
+ detour_functions.extend(get_clr_entry_point());
DetourDetachGuard::detour_functions(ptr::null_mut(), detour_functions, HashMap::new())
}
+unsafe fn override_entry_point() -> bool {
+ let exe_handle = GetModuleHandleW(ptr::null());
+ let dos_signature = exe_handle as *mut u16;
+ if *dos_signature != DOS_MAGIC {
+ return false;
+ }
+ let pe_offset = *((exe_handle as *mut u8).add(PE_POINTER_OFFSET as usize) as *mut u32);
+ let pe_sig = (exe_handle as *mut u8).add(pe_offset as usize) as *mut u32;
+ if (*pe_sig) != PE_MAGIC {
+ return false;
+ }
+ let coff_header = pe_sig.add(1) as *mut CoffHeader;
+ let standard_coff_fields = coff_header.add(1) as *mut StandardFields64;
+ if (*standard_coff_fields).magic != pe::optional_header::MAGIC_64 {
+ return false;
+ }
+ let entry_point = mem::transmute::<_, unsafe extern "system" fn() -> DWORD>(
+ (exe_handle as *mut u8).add((*standard_coff_fields).address_of_entry_point as usize),
+ );
+ let mut allocated_size = 0;
+ let exe_region = DetourAllocateRegionWithinJumpBounds(exe_handle as _, &mut allocated_size);
+ if (allocated_size as usize) < mem::size_of::<JmpThunk64>() || exe_region == ptr::null_mut() {
+ return false;
+ }
+ MAIN = entry_point;
+ *(exe_region as *mut JmpThunk64) = JmpThunk64::new(zluda_main);
+ FlushInstructionCache(
+ GetCurrentProcess(),
+ exe_region,
+ mem::size_of::<JmpThunk64>(),
+ );
+ let new_address_of_entry_point = (exe_region as *mut u8).offset_from(exe_handle as *mut u8);
+ let entry_point_offset = offset_of!(StandardFields64, address_of_entry_point);
+ let mut system_info = mem::zeroed();
+ GetSystemInfo(&mut system_info);
+ let pointer_to_address_of_entry_point =
+ (standard_coff_fields as *mut u8).add(entry_point_offset) as *mut i32;
+ let page_size = system_info.dwPageSize as usize;
+ let page_start = (((pointer_to_address_of_entry_point as usize) / page_size) * page_size) as _;
+ let mut old_protect = 0;
+ if VirtualProtect(page_start, page_size, PAGE_READWRITE, &mut old_protect) == 0 {
+ return false;
+ }
+ *pointer_to_address_of_entry_point = new_address_of_entry_point as i32;
+ if VirtualProtect(page_start, page_size, old_protect, &mut old_protect) == 0 {
+ return false;
+ }
+ true
+}
+
+// mov rax, $address;
+// jmp rax;
+// int 3;
+#[repr(packed)]
+#[allow(dead_code)]
+#[cfg(target_pointer_width = "64")]
+struct JmpThunk64 {
+ mov_rax: [u8; 2],
+ address: u64,
+ jmp_rax: [u8; 2],
+ int3: u8,
+}
+
+impl JmpThunk64 {
+ fn new<T: Sized>(target: unsafe extern "system" fn() -> T) -> Self {
+ JmpThunk64 {
+ mov_rax: [0x48, 0xB8],
+ address: target as u64,
+ jmp_rax: [0xFF, 0xE0],
+ int3: 0xcc,
+ }
+ }
+}
+
+unsafe fn get_clr_entry_point() -> Option<(*mut *mut c_void, *mut c_void)> {
+ let mscoree = GetModuleHandleA(b"mscoree\0".as_ptr() as _);
+ if mscoree == ptr::null_mut() {
+ return None;
+ }
+ let proc = GetProcAddress(mscoree, b"_CorExeMain\0".as_ptr() as _);
+ if proc == ptr::null_mut() {
+ return None;
+ }
+ COR_EXE_MAIN = mem::transmute(proc);
+ Some((&mut COR_EXE_MAIN as *mut _ as _, zluda_main_clr as _))
+}
+
fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> {
match get_payload(&PAYLOAD_NVCUDA_GUID) {
None => None,