diff options
author | Andrzej Janik <[email protected]> | 2021-01-03 18:45:48 +0100 |
---|---|---|
committer | GitHub <[email protected]> | 2021-01-03 18:45:48 +0100 |
commit | 2c0e9b912fe341bd1e513614014fa43b666d257d (patch) | |
tree | b5d3aa00a5192230657792833450848ceb557a1a /zluda_redirect | |
parent | 659b2c6ec431c3f1103e700a20da4c66467aa35d (diff) | |
download | ZLUDA-2c0e9b912fe341bd1e513614014fa43b666d257d.tar.gz ZLUDA-2c0e9b912fe341bd1e513614014fa43b666d257d.zip |
Fix Windows ZLUDA injector (#26)
Fix various bugs in injector and redirector, make them more robust and enable building them by default
Diffstat (limited to 'zluda_redirect')
-rw-r--r-- | zluda_redirect/Cargo.toml | 3 | ||||
-rw-r--r-- | zluda_redirect/src/lib.rs | 255 | ||||
-rw-r--r-- | zluda_redirect/src/payload_guid.rs | 6 |
3 files changed, 209 insertions, 55 deletions
diff --git a/zluda_redirect/Cargo.toml b/zluda_redirect/Cargo.toml index 46069bc..1b6b958 100644 --- a/zluda_redirect/Cargo.toml +++ b/zluda_redirect/Cargo.toml @@ -8,7 +8,6 @@ edition = "2018" crate-type = ["cdylib"] [target.'cfg(windows)'.dependencies] -detours-sys = "0.1" +detours-sys = { path = "../detours-sys" } wchar = "0.6" -guid = "0.1" winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "std"] }
\ No newline at end of file diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs index c6cd4be..d0497a3 100644 --- a/zluda_redirect/src/lib.rs +++ b/zluda_redirect/src/lib.rs @@ -1,105 +1,254 @@ -#![cfg(windows)] +#![cfg(target_os = "windows")] extern crate detours_sys; -#[macro_use] -extern crate guid; extern crate winapi; -use std::mem; +use std::{mem, ptr, slice}; use detours_sys::{ DetourAttach, DetourDetach, DetourRestoreAfterWith, DetourTransactionBegin, DetourTransactionCommit, DetourUpdateThread, }; -use wchar::{wch, wch_c}; -use winapi::shared::minwindef::{DWORD, FALSE, HMODULE, TRUE}; -use winapi::um::libloaderapi::LoadLibraryExW; +use wchar::wch; use winapi::um::processthreadsapi::GetCurrentThread; -use winapi::um::winbase::lstrcmpiW; use winapi::um::winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR}; +use winapi::{ + shared::minwindef::{DWORD, FALSE, HMODULE, TRUE}, + um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR}, +}; +use winapi::{ + shared::winerror::NO_ERROR, + um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW}, +}; + +include!("payload_guid.rs"); + +const NVCUDA_UTF8: &'static str = "NVCUDA.DLL"; +const NVCUDA_UTF16: &[u16] = wch!("NVCUDA.DLL"); +static mut ZLUDA_PATH_UTF8: Vec<u8> = Vec::new(); +static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None; -const NVCUDA_PATH: &[u16] = wch_c!(r"C:\WINDOWS\system32\nvcuda.dll"); -const ZLUDA_DLL: &[u16] = wch!(r"nvcuda.dll"); -static mut ZLUDA_PATH: Option<Vec<u16>> = None; +static mut LOAD_LIBRARY_A: unsafe extern "system" fn(lpLibFileName: LPCSTR) -> HMODULE = + LoadLibraryA; + +static mut LOAD_LIBRARY_W: unsafe extern "system" fn(lpLibFileName: LPCWSTR) -> HMODULE = + LoadLibraryW; + +static mut LOAD_LIBRARY_EX_A: unsafe extern "system" fn( + lpLibFileName: LPCSTR, + hFile: HANDLE, + dwFlags: DWORD, +) -> HMODULE = LoadLibraryExA; -static mut LOAD_LIBRARY_EX: unsafe extern "system" fn( +static mut LOAD_LIBRARY_EX_W: unsafe extern "system" fn( lpLibFileName: LPCWSTR, hFile: HANDLE, dwFlags: DWORD, ) -> HMODULE = LoadLibraryExW; #[allow(non_snake_case)] -#[no_mangle] +unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE { + let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) { + ZLUDA_PATH_UTF8.as_ptr() as *const _ + } else { + lpLibFileName + }; + (LOAD_LIBRARY_A)(nvcuda_file_name) +} + +#[allow(non_snake_case)] +unsafe extern "system" fn ZludaLoadLibraryW(lpLibFileName: LPCWSTR) -> HMODULE { + let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) { + ZLUDA_PATH_UTF16.unwrap().as_ptr() + } else { + lpLibFileName + }; + (LOAD_LIBRARY_W)(nvcuda_file_name) +} + +#[allow(non_snake_case)] +unsafe extern "system" fn ZludaLoadLibraryExA( + lpLibFileName: LPCSTR, + hFile: HANDLE, + dwFlags: DWORD, +) -> HMODULE { + let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) { + ZLUDA_PATH_UTF8.as_ptr() as *const _ + } else { + lpLibFileName + }; + (LOAD_LIBRARY_EX_A)(nvcuda_file_name, hFile, dwFlags) +} + +#[allow(non_snake_case)] unsafe extern "system" fn ZludaLoadLibraryExW( lpLibFileName: LPCWSTR, hFile: HANDLE, dwFlags: DWORD, ) -> HMODULE { - let nvcuda_file_name = if lstrcmpiW(lpLibFileName, NVCUDA_PATH.as_ptr()) == 0 { - ZLUDA_PATH.as_ref().unwrap().as_ptr() + let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) { + ZLUDA_PATH_UTF16.unwrap().as_ptr() } else { lpLibFileName }; - (LOAD_LIBRARY_EX)(nvcuda_file_name, hFile, dwFlags) + (LOAD_LIBRARY_EX_W)(nvcuda_file_name, hFile, dwFlags) +} + +fn is_nvcuda_dll_utf8(lib: *const u8) -> bool { + is_nvcuda_dll(lib, 0, NVCUDA_UTF8.as_bytes(), |c| { + if c >= 'a' as u8 && c <= 'z' as u8 { + c - 32 + } else { + c + } + }) +} +fn is_nvcuda_dll_utf16(lib: *const u16) -> bool { + is_nvcuda_dll(lib, 0u16, NVCUDA_UTF16, |c| { + if c >= 'a' as u16 && c <= 'z' as u16 { + c - 32 + } else { + c + } + }) +} + +fn is_nvcuda_dll<T: Copy + PartialEq>( + lib: *const T, + zero: T, + dll_name: &[T], + uppercase: impl Fn(T) -> T, +) -> bool { + let mut len = 0; + loop { + if unsafe { *lib.offset(len) } == zero { + break; + } + len += 1; + } + if (len as usize) < dll_name.len() { + return false; + } + let slice = + unsafe { slice::from_raw_parts(lib.offset(len - dll_name.len() as isize), dll_name.len()) }; + for i in 0..dll_name.len() { + if uppercase(slice[i]) != dll_name[i] { + return false; + } + } + true } #[allow(non_snake_case)] #[no_mangle] unsafe extern "system" fn DllMain(_: *const u8, dwReason: u32, _: *const u8) -> i32 { if dwReason == DLL_PROCESS_ATTACH { - DetourRestoreAfterWith(); + if DetourRestoreAfterWith() == FALSE { + return FALSE; + } match get_zluda_dll_path() { - Some((path, len)) => set_zluda_dll_path(path, len), + Some(path) => { + ZLUDA_PATH_UTF16 = Some(path); + ZLUDA_PATH_UTF8 = String::from_utf16_lossy(path).into_bytes(); + } None => return FALSE, } - DetourTransactionBegin(); - DetourUpdateThread(GetCurrentThread()); - DetourAttach( - std::mem::transmute(&mut LOAD_LIBRARY_EX), + if DetourTransactionBegin() != NO_ERROR as i32 { + return FALSE; + } + if DetourUpdateThread(GetCurrentThread()) != NO_ERROR as i32 { + return FALSE; + } + if DetourAttach( + mem::transmute(&mut LOAD_LIBRARY_A), + ZludaLoadLibraryA as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourAttach( + mem::transmute(&mut LOAD_LIBRARY_W), + ZludaLoadLibraryW as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourAttach( + mem::transmute(&mut LOAD_LIBRARY_EX_A), + ZludaLoadLibraryExA as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourAttach( + mem::transmute(&mut LOAD_LIBRARY_EX_W), ZludaLoadLibraryExW as *mut _, - ); - DetourTransactionCommit(); + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourTransactionCommit() != NO_ERROR as i32 { + return FALSE; + } } else if dwReason == DLL_PROCESS_DETACH { - DetourTransactionBegin(); - DetourUpdateThread(GetCurrentThread()); - DetourDetach( - std::mem::transmute(&mut LOAD_LIBRARY_EX), + if DetourTransactionBegin() != NO_ERROR as i32 { + return FALSE; + } + if DetourUpdateThread(GetCurrentThread()) != NO_ERROR as i32 { + return FALSE; + } + if DetourDetach( + mem::transmute(&mut LOAD_LIBRARY_A), + ZludaLoadLibraryA as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourDetach( + mem::transmute(&mut LOAD_LIBRARY_W), + ZludaLoadLibraryW as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourDetach( + mem::transmute(&mut LOAD_LIBRARY_EX_A), + ZludaLoadLibraryExA as *mut _, + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourDetach( + mem::transmute(&mut LOAD_LIBRARY_EX_W), ZludaLoadLibraryExW as *mut _, - ); - DetourTransactionCommit(); + ) != NO_ERROR as i32 + { + return FALSE; + } + if DetourTransactionCommit() != NO_ERROR as i32 { + return FALSE; + } } TRUE } -fn get_zluda_dll_path() -> Option<(*const u16, usize)> { - let guid = guid! {"C225FC0C-00D7-40B8-935A-7E342A9344C1"}; - let mut module = std::ptr::null_mut(); +fn get_zluda_dll_path() -> Option<&'static [u16]> { + let mut module = ptr::null_mut(); loop { module = unsafe { detours_sys::DetourEnumerateModules(module) }; - if module == std::ptr::null_mut() { + if module == ptr::null_mut() { break; } let mut size = 0; - let payload = unsafe { - detours_sys::DetourFindPayload(module, std::mem::transmute(&guid), &mut size) - }; - if payload != std::ptr::null_mut() { - return Some((payload as *const _, (size as usize) / mem::size_of::<u16>())); + let payload = unsafe { detours_sys::DetourFindPayload(module, &PAYLOAD_GUID, &mut size) }; + if payload != ptr::null_mut() { + return unsafe { + Some(slice::from_raw_parts( + payload as *const _, + (size as usize) / mem::size_of::<u16>(), + )) + }; } } None } - -unsafe fn set_zluda_dll_path(path: *const u16, len: usize) { - let len = len as usize; - let mut result = Vec::<u16>::with_capacity(len + ZLUDA_DLL.len() + 2); - for i in 0..len { - result.push(*path.add(i)); - } - result.push(0x5c); // \ - for c in ZLUDA_DLL.iter().copied() { - result.push(c); - } - result.push(0); - ZLUDA_PATH = Some(result); -} diff --git a/zluda_redirect/src/payload_guid.rs b/zluda_redirect/src/payload_guid.rs new file mode 100644 index 0000000..eaf021d --- /dev/null +++ b/zluda_redirect/src/payload_guid.rs @@ -0,0 +1,6 @@ +const PAYLOAD_GUID: detours_sys::GUID = detours_sys::GUID {
+ Data1: 0xC225FC0C,
+ Data2: 0x00D7,
+ Data3: 0x40B8,
+ Data4: [0x93, 0x5A, 0x7E, 0x34, 0x2A, 0x93, 0x44, 0xC1],
+};
\ No newline at end of file |