diff options
Diffstat (limited to 'zluda_redirect/src/lib.rs')
-rw-r--r-- | zluda_redirect/src/lib.rs | 155 |
1 files changed, 127 insertions, 28 deletions
diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs index 657500c..0f73ced 100644 --- a/zluda_redirect/src/lib.rs +++ b/zluda_redirect/src/lib.rs @@ -12,15 +12,27 @@ use std::{ }; use detours_sys::{ - DetourAttach, DetourEnumerateExports, DetourGetEntryPoint, DetourRestoreAfterWith, - DetourTransactionAbort, DetourTransactionBegin, DetourTransactionCommit, - DetourUpdateProcessWithDll, DetourUpdateThread, + DetourAllocateRegionWithinJumpBounds, DetourAttach, DetourEnumerateExports, + DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin, + DetourTransactionCommit, DetourUpdateProcessWithDll, DetourUpdateThread, }; +use goblin::pe::{ + self, + header::{CoffHeader, DOS_MAGIC, PE_MAGIC, PE_POINTER_OFFSET}, + optional_header::StandardFields64, +}; +use memoffset::offset_of; use tempfile::TempDir; use wchar::wch; use winapi::{ shared::minwindef::{DWORD, FALSE, HMODULE, TRUE}, - um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR}, + um::{ + libloaderapi::{GetModuleHandleA, LoadLibraryExA}, + memoryapi::VirtualProtect, + processthreadsapi::{FlushInstructionCache, GetCurrentProcess}, + sysinfoapi::GetSystemInfo, + winnt::{LPCSTR, PAGE_READWRITE}, + }, }; use winapi::{ shared::minwindef::{FARPROC, HINSTANCE}, @@ -380,19 +392,27 @@ unsafe extern "system" fn ZludaCreateProcessWithTokenW( continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation) } -static mut MAIN: unsafe extern "system" fn() -> DWORD = ZludaMain; +static mut MAIN: unsafe extern "system" fn() -> DWORD = zluda_main; +static mut COR_EXE_MAIN: unsafe extern "system" fn() -> DWORD = zluda_main_clr; // https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#search-order-for-desktop-applications // "If a DLL with the same module name is already loaded in memory, the system // uses the loaded DLL, no matter which directory it is in. The system does not // search for the DLL." -#[allow(non_snake_case)] -unsafe extern "system" fn ZludaMain() -> DWORD { +unsafe extern "system" fn zluda_main() -> DWORD { + zluda_main_impl(MAIN) +} + +unsafe extern "system" fn zluda_main_clr() -> DWORD { + zluda_main_impl(COR_EXE_MAIN) +} + +unsafe fn zluda_main_impl(original: unsafe extern "system" fn() -> DWORD) -> DWORD { let temp_dir = match do_zluda_preload() { Ok(f) => f, - Err(e) => return e.raw_os_error().unwrap_or(E_UNEXPECTED) as u32, + Err(e) => return e.raw_os_error().unwrap_or(E_UNEXPECTED) as DWORD, }; - let result = MAIN(); + let result = original(); drop(temp_dir); result } @@ -768,17 +788,15 @@ unsafe fn initialize_current_module_name(current_module: HINSTANCE) -> bool { } unsafe fn get_cuinit() -> Option<(HMODULE, FARPROC)> { - let mut module = ptr::null_mut(); - loop { - module = detours_sys::DetourEnumerateModules(module); - if module == ptr::null_mut() { - return None; - } - let cuinit_addr = GetProcAddress(module as *mut _, b"cuInit\0".as_ptr() as *const _); - if cuinit_addr != ptr::null_mut() { - return Some((module as *mut _, cuinit_addr)); - } + let nvcuda = GetModuleHandleA(b"nvcuda\0".as_ptr() as _); + if nvcuda == ptr::null_mut() { + return None; + } + let cuinit_addr = GetProcAddress(nvcuda, b"cuInit\0".as_ptr() as _); + if cuinit_addr == ptr::null_mut() { + return None; } + Some((nvcuda as *mut _, cuinit_addr)) } #[must_use] @@ -850,17 +868,10 @@ unsafe extern "stdcall" fn gather_imports_impl( #[must_use] unsafe fn detour_main() -> Option<DetourDetachGuard> { - let exe_handle = GetModuleHandleW(ptr::null()); - let entry_point = DetourGetEntryPoint(exe_handle as _); - if entry_point == ptr::null_mut() { + if !override_entry_point() { return None; } - MAIN = mem::transmute(entry_point); - let detour_functions = vec![ - ( - &mut MAIN as *mut _ as *mut *mut c_void, - ZludaMain as *mut c_void, - ), + let mut detour_functions = vec![ ( &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void, ZludaLoadLibraryA as *mut c_void, @@ -875,9 +886,97 @@ unsafe fn detour_main() -> Option<DetourDetachGuard> { ZludaLoadLibraryExW as _, ), ]; + detour_functions.extend(get_clr_entry_point()); DetourDetachGuard::detour_functions(ptr::null_mut(), detour_functions, HashMap::new()) } +unsafe fn override_entry_point() -> bool { + let exe_handle = GetModuleHandleW(ptr::null()); + let dos_signature = exe_handle as *mut u16; + if *dos_signature != DOS_MAGIC { + return false; + } + let pe_offset = *((exe_handle as *mut u8).add(PE_POINTER_OFFSET as usize) as *mut u32); + let pe_sig = (exe_handle as *mut u8).add(pe_offset as usize) as *mut u32; + if (*pe_sig) != PE_MAGIC { + return false; + } + let coff_header = pe_sig.add(1) as *mut CoffHeader; + let standard_coff_fields = coff_header.add(1) as *mut StandardFields64; + if (*standard_coff_fields).magic != pe::optional_header::MAGIC_64 { + return false; + } + let entry_point = mem::transmute::<_, unsafe extern "system" fn() -> DWORD>( + (exe_handle as *mut u8).add((*standard_coff_fields).address_of_entry_point as usize), + ); + let mut allocated_size = 0; + let exe_region = DetourAllocateRegionWithinJumpBounds(exe_handle as _, &mut allocated_size); + if (allocated_size as usize) < mem::size_of::<JmpThunk64>() || exe_region == ptr::null_mut() { + return false; + } + MAIN = entry_point; + *(exe_region as *mut JmpThunk64) = JmpThunk64::new(zluda_main); + FlushInstructionCache( + GetCurrentProcess(), + exe_region, + mem::size_of::<JmpThunk64>(), + ); + let new_address_of_entry_point = (exe_region as *mut u8).offset_from(exe_handle as *mut u8); + let entry_point_offset = offset_of!(StandardFields64, address_of_entry_point); + let mut system_info = mem::zeroed(); + GetSystemInfo(&mut system_info); + let pointer_to_address_of_entry_point = + (standard_coff_fields as *mut u8).add(entry_point_offset) as *mut i32; + let page_size = system_info.dwPageSize as usize; + let page_start = (((pointer_to_address_of_entry_point as usize) / page_size) * page_size) as _; + let mut old_protect = 0; + if VirtualProtect(page_start, page_size, PAGE_READWRITE, &mut old_protect) == 0 { + return false; + } + *pointer_to_address_of_entry_point = new_address_of_entry_point as i32; + if VirtualProtect(page_start, page_size, old_protect, &mut old_protect) == 0 { + return false; + } + true +} + +// mov rax, $address; +// jmp rax; +// int 3; +#[repr(packed)] +#[allow(dead_code)] +#[cfg(target_pointer_width = "64")] +struct JmpThunk64 { + mov_rax: [u8; 2], + address: u64, + jmp_rax: [u8; 2], + int3: u8, +} + +impl JmpThunk64 { + fn new<T: Sized>(target: unsafe extern "system" fn() -> T) -> Self { + JmpThunk64 { + mov_rax: [0x48, 0xB8], + address: target as u64, + jmp_rax: [0xFF, 0xE0], + int3: 0xcc, + } + } +} + +unsafe fn get_clr_entry_point() -> Option<(*mut *mut c_void, *mut c_void)> { + let mscoree = GetModuleHandleA(b"mscoree\0".as_ptr() as _); + if mscoree == ptr::null_mut() { + return None; + } + let proc = GetProcAddress(mscoree, b"_CorExeMain\0".as_ptr() as _); + if proc == ptr::null_mut() { + return None; + } + COR_EXE_MAIN = mem::transmute(proc); + Some((&mut COR_EXE_MAIN as *mut _ as _, zluda_main_clr as _)) +} + fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> { match get_payload(&PAYLOAD_NVCUDA_GUID) { None => None, |