aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_redirect
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-01-03 18:45:48 +0100
committerGitHub <[email protected]>2021-01-03 18:45:48 +0100
commit2c0e9b912fe341bd1e513614014fa43b666d257d (patch)
treeb5d3aa00a5192230657792833450848ceb557a1a /zluda_redirect
parent659b2c6ec431c3f1103e700a20da4c66467aa35d (diff)
downloadZLUDA-2c0e9b912fe341bd1e513614014fa43b666d257d.tar.gz
ZLUDA-2c0e9b912fe341bd1e513614014fa43b666d257d.zip
Fix Windows ZLUDA injector (#26)
Fix various bugs in injector and redirector, make them more robust and enable building them by default
Diffstat (limited to 'zluda_redirect')
-rw-r--r--zluda_redirect/Cargo.toml3
-rw-r--r--zluda_redirect/src/lib.rs255
-rw-r--r--zluda_redirect/src/payload_guid.rs6
3 files changed, 209 insertions, 55 deletions
diff --git a/zluda_redirect/Cargo.toml b/zluda_redirect/Cargo.toml
index 46069bc..1b6b958 100644
--- a/zluda_redirect/Cargo.toml
+++ b/zluda_redirect/Cargo.toml
@@ -8,7 +8,6 @@ edition = "2018"
crate-type = ["cdylib"]
[target.'cfg(windows)'.dependencies]
-detours-sys = "0.1"
+detours-sys = { path = "../detours-sys" }
wchar = "0.6"
-guid = "0.1"
winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "std"] } \ No newline at end of file
diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs
index c6cd4be..d0497a3 100644
--- a/zluda_redirect/src/lib.rs
+++ b/zluda_redirect/src/lib.rs
@@ -1,105 +1,254 @@
-#![cfg(windows)]
+#![cfg(target_os = "windows")]
extern crate detours_sys;
-#[macro_use]
-extern crate guid;
extern crate winapi;
-use std::mem;
+use std::{mem, ptr, slice};
use detours_sys::{
DetourAttach, DetourDetach, DetourRestoreAfterWith, DetourTransactionBegin,
DetourTransactionCommit, DetourUpdateThread,
};
-use wchar::{wch, wch_c};
-use winapi::shared::minwindef::{DWORD, FALSE, HMODULE, TRUE};
-use winapi::um::libloaderapi::LoadLibraryExW;
+use wchar::wch;
use winapi::um::processthreadsapi::GetCurrentThread;
-use winapi::um::winbase::lstrcmpiW;
use winapi::um::winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR};
+use winapi::{
+ shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
+ um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
+};
+use winapi::{
+ shared::winerror::NO_ERROR,
+ um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW},
+};
+
+include!("payload_guid.rs");
+
+const NVCUDA_UTF8: &'static str = "NVCUDA.DLL";
+const NVCUDA_UTF16: &[u16] = wch!("NVCUDA.DLL");
+static mut ZLUDA_PATH_UTF8: Vec<u8> = Vec::new();
+static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None;
-const NVCUDA_PATH: &[u16] = wch_c!(r"C:\WINDOWS\system32\nvcuda.dll");
-const ZLUDA_DLL: &[u16] = wch!(r"nvcuda.dll");
-static mut ZLUDA_PATH: Option<Vec<u16>> = None;
+static mut LOAD_LIBRARY_A: unsafe extern "system" fn(lpLibFileName: LPCSTR) -> HMODULE =
+ LoadLibraryA;
+
+static mut LOAD_LIBRARY_W: unsafe extern "system" fn(lpLibFileName: LPCWSTR) -> HMODULE =
+ LoadLibraryW;
+
+static mut LOAD_LIBRARY_EX_A: unsafe extern "system" fn(
+ lpLibFileName: LPCSTR,
+ hFile: HANDLE,
+ dwFlags: DWORD,
+) -> HMODULE = LoadLibraryExA;
-static mut LOAD_LIBRARY_EX: unsafe extern "system" fn(
+static mut LOAD_LIBRARY_EX_W: unsafe extern "system" fn(
lpLibFileName: LPCWSTR,
hFile: HANDLE,
dwFlags: DWORD,
) -> HMODULE = LoadLibraryExW;
#[allow(non_snake_case)]
-#[no_mangle]
+unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE {
+ let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
+ ZLUDA_PATH_UTF8.as_ptr() as *const _
+ } else {
+ lpLibFileName
+ };
+ (LOAD_LIBRARY_A)(nvcuda_file_name)
+}
+
+#[allow(non_snake_case)]
+unsafe extern "system" fn ZludaLoadLibraryW(lpLibFileName: LPCWSTR) -> HMODULE {
+ let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
+ ZLUDA_PATH_UTF16.unwrap().as_ptr()
+ } else {
+ lpLibFileName
+ };
+ (LOAD_LIBRARY_W)(nvcuda_file_name)
+}
+
+#[allow(non_snake_case)]
+unsafe extern "system" fn ZludaLoadLibraryExA(
+ lpLibFileName: LPCSTR,
+ hFile: HANDLE,
+ dwFlags: DWORD,
+) -> HMODULE {
+ let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
+ ZLUDA_PATH_UTF8.as_ptr() as *const _
+ } else {
+ lpLibFileName
+ };
+ (LOAD_LIBRARY_EX_A)(nvcuda_file_name, hFile, dwFlags)
+}
+
+#[allow(non_snake_case)]
unsafe extern "system" fn ZludaLoadLibraryExW(
lpLibFileName: LPCWSTR,
hFile: HANDLE,
dwFlags: DWORD,
) -> HMODULE {
- let nvcuda_file_name = if lstrcmpiW(lpLibFileName, NVCUDA_PATH.as_ptr()) == 0 {
- ZLUDA_PATH.as_ref().unwrap().as_ptr()
+ let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
+ ZLUDA_PATH_UTF16.unwrap().as_ptr()
} else {
lpLibFileName
};
- (LOAD_LIBRARY_EX)(nvcuda_file_name, hFile, dwFlags)
+ (LOAD_LIBRARY_EX_W)(nvcuda_file_name, hFile, dwFlags)
+}
+
+fn is_nvcuda_dll_utf8(lib: *const u8) -> bool {
+ is_nvcuda_dll(lib, 0, NVCUDA_UTF8.as_bytes(), |c| {
+ if c >= 'a' as u8 && c <= 'z' as u8 {
+ c - 32
+ } else {
+ c
+ }
+ })
+}
+fn is_nvcuda_dll_utf16(lib: *const u16) -> bool {
+ is_nvcuda_dll(lib, 0u16, NVCUDA_UTF16, |c| {
+ if c >= 'a' as u16 && c <= 'z' as u16 {
+ c - 32
+ } else {
+ c
+ }
+ })
+}
+
+fn is_nvcuda_dll<T: Copy + PartialEq>(
+ lib: *const T,
+ zero: T,
+ dll_name: &[T],
+ uppercase: impl Fn(T) -> T,
+) -> bool {
+ let mut len = 0;
+ loop {
+ if unsafe { *lib.offset(len) } == zero {
+ break;
+ }
+ len += 1;
+ }
+ if (len as usize) < dll_name.len() {
+ return false;
+ }
+ let slice =
+ unsafe { slice::from_raw_parts(lib.offset(len - dll_name.len() as isize), dll_name.len()) };
+ for i in 0..dll_name.len() {
+ if uppercase(slice[i]) != dll_name[i] {
+ return false;
+ }
+ }
+ true
}
#[allow(non_snake_case)]
#[no_mangle]
unsafe extern "system" fn DllMain(_: *const u8, dwReason: u32, _: *const u8) -> i32 {
if dwReason == DLL_PROCESS_ATTACH {
- DetourRestoreAfterWith();
+ if DetourRestoreAfterWith() == FALSE {
+ return FALSE;
+ }
match get_zluda_dll_path() {
- Some((path, len)) => set_zluda_dll_path(path, len),
+ Some(path) => {
+ ZLUDA_PATH_UTF16 = Some(path);
+ ZLUDA_PATH_UTF8 = String::from_utf16_lossy(path).into_bytes();
+ }
None => return FALSE,
}
- DetourTransactionBegin();
- DetourUpdateThread(GetCurrentThread());
- DetourAttach(
- std::mem::transmute(&mut LOAD_LIBRARY_EX),
+ if DetourTransactionBegin() != NO_ERROR as i32 {
+ return FALSE;
+ }
+ if DetourUpdateThread(GetCurrentThread()) != NO_ERROR as i32 {
+ return FALSE;
+ }
+ if DetourAttach(
+ mem::transmute(&mut LOAD_LIBRARY_A),
+ ZludaLoadLibraryA as *mut _,
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourAttach(
+ mem::transmute(&mut LOAD_LIBRARY_W),
+ ZludaLoadLibraryW as *mut _,
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourAttach(
+ mem::transmute(&mut LOAD_LIBRARY_EX_A),
+ ZludaLoadLibraryExA as *mut _,
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourAttach(
+ mem::transmute(&mut LOAD_LIBRARY_EX_W),
ZludaLoadLibraryExW as *mut _,
- );
- DetourTransactionCommit();
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourTransactionCommit() != NO_ERROR as i32 {
+ return FALSE;
+ }
} else if dwReason == DLL_PROCESS_DETACH {
- DetourTransactionBegin();
- DetourUpdateThread(GetCurrentThread());
- DetourDetach(
- std::mem::transmute(&mut LOAD_LIBRARY_EX),
+ if DetourTransactionBegin() != NO_ERROR as i32 {
+ return FALSE;
+ }
+ if DetourUpdateThread(GetCurrentThread()) != NO_ERROR as i32 {
+ return FALSE;
+ }
+ if DetourDetach(
+ mem::transmute(&mut LOAD_LIBRARY_A),
+ ZludaLoadLibraryA as *mut _,
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourDetach(
+ mem::transmute(&mut LOAD_LIBRARY_W),
+ ZludaLoadLibraryW as *mut _,
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourDetach(
+ mem::transmute(&mut LOAD_LIBRARY_EX_A),
+ ZludaLoadLibraryExA as *mut _,
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourDetach(
+ mem::transmute(&mut LOAD_LIBRARY_EX_W),
ZludaLoadLibraryExW as *mut _,
- );
- DetourTransactionCommit();
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourTransactionCommit() != NO_ERROR as i32 {
+ return FALSE;
+ }
}
TRUE
}
-fn get_zluda_dll_path() -> Option<(*const u16, usize)> {
- let guid = guid! {"C225FC0C-00D7-40B8-935A-7E342A9344C1"};
- let mut module = std::ptr::null_mut();
+fn get_zluda_dll_path() -> Option<&'static [u16]> {
+ let mut module = ptr::null_mut();
loop {
module = unsafe { detours_sys::DetourEnumerateModules(module) };
- if module == std::ptr::null_mut() {
+ if module == ptr::null_mut() {
break;
}
let mut size = 0;
- let payload = unsafe {
- detours_sys::DetourFindPayload(module, std::mem::transmute(&guid), &mut size)
- };
- if payload != std::ptr::null_mut() {
- return Some((payload as *const _, (size as usize) / mem::size_of::<u16>()));
+ let payload = unsafe { detours_sys::DetourFindPayload(module, &PAYLOAD_GUID, &mut size) };
+ if payload != ptr::null_mut() {
+ return unsafe {
+ Some(slice::from_raw_parts(
+ payload as *const _,
+ (size as usize) / mem::size_of::<u16>(),
+ ))
+ };
}
}
None
}
-
-unsafe fn set_zluda_dll_path(path: *const u16, len: usize) {
- let len = len as usize;
- let mut result = Vec::<u16>::with_capacity(len + ZLUDA_DLL.len() + 2);
- for i in 0..len {
- result.push(*path.add(i));
- }
- result.push(0x5c); // \
- for c in ZLUDA_DLL.iter().copied() {
- result.push(c);
- }
- result.push(0);
- ZLUDA_PATH = Some(result);
-}
diff --git a/zluda_redirect/src/payload_guid.rs b/zluda_redirect/src/payload_guid.rs
new file mode 100644
index 0000000..eaf021d
--- /dev/null
+++ b/zluda_redirect/src/payload_guid.rs
@@ -0,0 +1,6 @@
+const PAYLOAD_GUID: detours_sys::GUID = detours_sys::GUID {
+ Data1: 0xC225FC0C,
+ Data2: 0x00D7,
+ Data3: 0x40B8,
+ Data4: [0x93, 0x5A, 0x7E, 0x34, 0x2A, 0x93, 0x44, 0xC1],
+}; \ No newline at end of file