aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-01-02 22:01:07 +0100
committerAndrzej Janik <[email protected]>2021-01-03 17:56:13 +0100
commit49a0ea377c703a364947f2791524616908c9537a (patch)
tree14d005b0c10c1f87691bae78e49effa6c9a36ffc
parent639d1255e9928dbae2dfe71a0b8225504b23b382 (diff)
downloadZLUDA-49a0ea377c703a364947f2791524616908c9537a.tar.gz
ZLUDA-49a0ea377c703a364947f2791524616908c9537a.zip
Make redirection DLL more robust, redirect more calls
-rw-r--r--zluda_inject/Cargo.toml1
-rw-r--r--zluda_inject/src/bin.rs190
-rw-r--r--zluda_inject/src/main.rs2
-rw-r--r--zluda_redirect/src/lib.rs242
4 files changed, 291 insertions, 144 deletions
diff --git a/zluda_inject/Cargo.toml b/zluda_inject/Cargo.toml
index 1f10178..fe551eb 100644
--- a/zluda_inject/Cargo.toml
+++ b/zluda_inject/Cargo.toml
@@ -12,4 +12,3 @@ path = "src/main.rs"
winapi = { version = "0.3", features = ["processthreadsapi", "std", "synchapi"] }
detours-sys = { path = "../detours-sys" }
clap = "2.33"
-wstr = "0.2" \ No newline at end of file
diff --git a/zluda_inject/src/bin.rs b/zluda_inject/src/bin.rs
index 83975e1..6aa3e5f 100644
--- a/zluda_inject/src/bin.rs
+++ b/zluda_inject/src/bin.rs
@@ -1,65 +1,32 @@
-use std::ffi::OsStr;
-use std::os::windows::ffi::OsStrExt;
-use std::os::windows::ffi::OsStringExt;
+use std::env;
+use std::env::Args;
+use std::mem;
+use std::path::Path;
use std::ptr;
use std::{error::Error, process};
-use std::{ffi::OsString, mem};
use winapi::um::{
- libloaderapi::GetModuleFileNameW,
processthreadsapi::{GetExitCodeProcess, ResumeThread},
synchapi::WaitForSingleObject,
};
-use winapi::{shared::minwindef, um::errhandlingapi::GetLastError};
-use winapi::{
- shared::winerror::ERROR_INSUFFICIENT_BUFFER,
- um::winbase::{INFINITE, WAIT_FAILED},
-};
-use clap::{App, AppSettings, Arg};
+use winapi::um::winbase::{INFINITE, WAIT_FAILED};
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
-static ZLUDA_DLL: &'static [u16] = wstr!("nvcuda.dll");
+static ZLUDA_DLL: &'static str = "nvcuda.dll";
include!("../../zluda_redirect/src/payload_guid.rs");
pub fn main_impl() -> Result<(), Box<dyn Error>> {
- let matches = App::new("ZLUDA injector")
- .setting(AppSettings::TrailingVarArg)
- .arg(
- Arg::with_name("EXE")
- .help("Path to the executable to be injected with ZLUDA")
- .required(true),
- )
- .arg(
- Arg::with_name("ARGS")
- .multiple(true)
- .help("Arguments that will be passed to <EXE>"),
- )
- .get_matches();
- let exe = matches.value_of_os("EXE").unwrap();
- let args: Vec<&OsStr> = matches
- .values_of_os("ARGS")
- .map(|x| x.collect())
- .unwrap_or_else(|| Vec::new());
- let mut cmd_line = Vec::<u16>::with_capacity(exe.len() + 2);
- cmd_line.push('\"' as u16);
- copy_to(exe, &mut cmd_line);
- cmd_line.push('\"' as u16);
- cmd_line.push(' ' as u16);
- args.split_last().map(|(last_arg, args)| {
- for arg in args {
- cmd_line.reserve(arg.len());
- copy_to(arg, &mut cmd_line);
- cmd_line.push(' ' as u16);
- }
- copy_to(last_arg, &mut cmd_line);
- });
-
- cmd_line.push(0);
- let mut injector_path = get_injector_path()?;
- trim_to_parent(&mut injector_path);
- let redirect_path = create_redirect_path(&injector_path);
+ let args = env::args();
+ if args.len() == 0 {
+ print_help();
+ process::exit(1);
+ }
+ let mut cmd_line = construct_command_line(args);
+ let injector_path = env::current_exe()?;
+ let injector_dir = injector_path.parent().unwrap();
+ let redirect_path = create_redirect_path(injector_dir);
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
os_call!(
@@ -79,7 +46,7 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
),
|x| x != 0
);
- let mut zluda_path = create_zluda_path(injector_path);
+ let mut zluda_path = create_zluda_path(injector_dir);
os_call!(
detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess,
@@ -90,6 +57,7 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|x| x != 0
);
os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1);
+ // TODO: kill the child process if we were killed
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x
!= WAIT_FAILED);
let mut child_exit_code: u32 = 0;
@@ -100,56 +68,88 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
process::exit(child_exit_code as i32)
}
-fn trim_to_parent(injector_path: &mut Vec<u16>) {
- let slash_idx = injector_path
- .iter()
- .enumerate()
- .rev()
- .find_map(|(idx, char)| {
- if *char == '/' as u16 || *char == '\\' as u16 {
- Some(idx)
- } else {
- None
- }
- });
- if let Some(idx) = slash_idx {
- injector_path.truncate(idx + 1);
- }
-}
-
-fn create_redirect_path(injector_dir: &[u16]) -> Vec<u8> {
- let os_string: OsString = OsString::from_wide(injector_dir);
- let mut utf8_string = os_string.to_string_lossy().as_bytes().to_vec();
- utf8_string.extend(REDIRECT_DLL.as_bytes());
- utf8_string.push(0);
- utf8_string
+fn print_help() {
+ println!(
+ "USAGE:
+ zluda <EXE> [ARGS]...
+ARGS:
+ <EXE> Path to the executable to be injected with ZLUDA
+ <ARGS>... Arguments that will be passed to <EXE>
+"
+ );
}
-fn create_zluda_path(mut injector_dir: Vec<u16>) -> Vec<u16> {
- injector_dir.extend(ZLUDA_DLL);
- injector_dir
+// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
+fn construct_command_line(args: Args) -> Vec<u16> {
+ let mut cmd_line = Vec::new();
+ let args_len = args.len();
+ for (idx, arg) in args.enumerate().skip(1) {
+ if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) {
+ cmd_line.extend(arg.encode_utf16());
+ } else {
+ cmd_line.push('"' as u16); // "
+ let mut char_iter = arg.chars().peekable();
+ loop {
+ let mut current = char_iter.next();
+ let mut backslashes = 0;
+ match current {
+ Some('\\') => {
+ while let Some('\\') = char_iter.peek() {
+ backslashes += 1;
+ char_iter.next();
+ }
+ current = char_iter.next();
+ }
+ _ => {}
+ }
+ match current {
+ None => {
+ for _ in 0..(backslashes * 2) {
+ cmd_line.push('\\' as u16);
+ }
+ break;
+ }
+ Some('"') => {
+ for _ in 0..(backslashes * 2 + 1) {
+ cmd_line.push('\\' as u16);
+ }
+ cmd_line.push('"' as u16);
+ }
+ Some(c) => {
+ for _ in 0..backslashes {
+ cmd_line.push('\\' as u16);
+ }
+ let mut temp = [0u16; 2];
+ cmd_line.extend(&*c.encode_utf16(&mut temp));
+ }
+ }
+ }
+ cmd_line.push('"' as u16);
+ }
+ if idx < args_len - 1 {
+ cmd_line.push(' ' as u16);
+ }
+ }
+ cmd_line.push(0);
+ cmd_line
}
-fn copy_to(from: &OsStr, to: &mut Vec<u16>) {
- for x in from.encode_wide() {
- to.push(x);
- }
+fn create_redirect_path(injector_dir: &Path) -> Vec<u8> {
+ let mut injector_dir = injector_dir.to_path_buf();
+ injector_dir.push(REDIRECT_DLL);
+ let mut result = injector_dir.to_string_lossy().into_owned().into_bytes();
+ result.push(0);
+ result
}
-fn get_injector_path() -> Result<Vec<u16>, Box<dyn Error>> {
- let mut result = vec![0u16; minwindef::MAX_PATH];
- let mut copied;
- loop {
- copied = os_call!(
- GetModuleFileNameW(ptr::null_mut(), result.as_mut_ptr(), result.len() as u32),
- |x| x != 0
- );
- if copied != result.len() as u32 {
- break;
- }
- os_call!(GetLastError(), |x| x != ERROR_INSUFFICIENT_BUFFER);
- result.resize(result.len() * 2, 0);
- }
- result.truncate(copied as usize);
- Ok(result)
+fn create_zluda_path(injector_dir: &Path) -> Vec<u16> {
+ let mut injector_dir = injector_dir.to_path_buf();
+ injector_dir.push(ZLUDA_DLL);
+ let mut result = injector_dir
+ .to_string_lossy()
+ .as_ref()
+ .encode_utf16()
+ .collect::<Vec<_>>();
+ result.push(0);
+ result
}
diff --git a/zluda_inject/src/main.rs b/zluda_inject/src/main.rs
index 198c391..e7b1176 100644
--- a/zluda_inject/src/main.rs
+++ b/zluda_inject/src/main.rs
@@ -1,8 +1,6 @@
extern crate clap;
extern crate detours_sys;
extern crate winapi;
-#[macro_use]
-extern crate wstr;
#[macro_use]
#[cfg(target_os = "windows")]
diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs
index 0f66b35..d0497a3 100644
--- a/zluda_redirect/src/lib.rs
+++ b/zluda_redirect/src/lib.rs
@@ -3,75 +3,236 @@
extern crate detours_sys;
extern crate winapi;
-use std::{mem, ptr};
+use std::{mem, ptr, slice};
use detours_sys::{
DetourAttach, DetourDetach, DetourRestoreAfterWith, DetourTransactionBegin,
DetourTransactionCommit, DetourUpdateThread,
};
-use wchar::{wch, wch_c};
-use winapi::shared::minwindef::{DWORD, FALSE, HMODULE, TRUE};
-use winapi::um::libloaderapi::LoadLibraryExW;
+use wchar::wch;
use winapi::um::processthreadsapi::GetCurrentThread;
-use winapi::um::winbase::lstrcmpiW;
use winapi::um::winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR};
+use winapi::{
+ shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
+ um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
+};
+use winapi::{
+ shared::winerror::NO_ERROR,
+ um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW},
+};
include!("payload_guid.rs");
-const NVCUDA_PATH: &[u16] = wch_c!(r"C:\WINDOWS\system32\nvcuda.dll");
-const ZLUDA_DLL: &[u16] = wch!(r"nvcuda.dll");
-static mut ZLUDA_PATH: Option<Vec<u16>> = None;
+const NVCUDA_UTF8: &'static str = "NVCUDA.DLL";
+const NVCUDA_UTF16: &[u16] = wch!("NVCUDA.DLL");
+static mut ZLUDA_PATH_UTF8: Vec<u8> = Vec::new();
+static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None;
+
+static mut LOAD_LIBRARY_A: unsafe extern "system" fn(lpLibFileName: LPCSTR) -> HMODULE =
+ LoadLibraryA;
+
+static mut LOAD_LIBRARY_W: unsafe extern "system" fn(lpLibFileName: LPCWSTR) -> HMODULE =
+ LoadLibraryW;
-static mut LOAD_LIBRARY_EX: unsafe extern "system" fn(
+static mut LOAD_LIBRARY_EX_A: unsafe extern "system" fn(
+ lpLibFileName: LPCSTR,
+ hFile: HANDLE,
+ dwFlags: DWORD,
+) -> HMODULE = LoadLibraryExA;
+
+static mut LOAD_LIBRARY_EX_W: unsafe extern "system" fn(
lpLibFileName: LPCWSTR,
hFile: HANDLE,
dwFlags: DWORD,
) -> HMODULE = LoadLibraryExW;
#[allow(non_snake_case)]
-#[no_mangle]
+unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE {
+ let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
+ ZLUDA_PATH_UTF8.as_ptr() as *const _
+ } else {
+ lpLibFileName
+ };
+ (LOAD_LIBRARY_A)(nvcuda_file_name)
+}
+
+#[allow(non_snake_case)]
+unsafe extern "system" fn ZludaLoadLibraryW(lpLibFileName: LPCWSTR) -> HMODULE {
+ let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
+ ZLUDA_PATH_UTF16.unwrap().as_ptr()
+ } else {
+ lpLibFileName
+ };
+ (LOAD_LIBRARY_W)(nvcuda_file_name)
+}
+
+#[allow(non_snake_case)]
+unsafe extern "system" fn ZludaLoadLibraryExA(
+ lpLibFileName: LPCSTR,
+ hFile: HANDLE,
+ dwFlags: DWORD,
+) -> HMODULE {
+ let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
+ ZLUDA_PATH_UTF8.as_ptr() as *const _
+ } else {
+ lpLibFileName
+ };
+ (LOAD_LIBRARY_EX_A)(nvcuda_file_name, hFile, dwFlags)
+}
+
+#[allow(non_snake_case)]
unsafe extern "system" fn ZludaLoadLibraryExW(
lpLibFileName: LPCWSTR,
hFile: HANDLE,
dwFlags: DWORD,
) -> HMODULE {
- let nvcuda_file_name = if lstrcmpiW(lpLibFileName, NVCUDA_PATH.as_ptr()) == 0 {
- ZLUDA_PATH.as_ref().unwrap().as_ptr()
+ let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
+ ZLUDA_PATH_UTF16.unwrap().as_ptr()
} else {
lpLibFileName
};
- (LOAD_LIBRARY_EX)(nvcuda_file_name, hFile, dwFlags)
+ (LOAD_LIBRARY_EX_W)(nvcuda_file_name, hFile, dwFlags)
+}
+
+fn is_nvcuda_dll_utf8(lib: *const u8) -> bool {
+ is_nvcuda_dll(lib, 0, NVCUDA_UTF8.as_bytes(), |c| {
+ if c >= 'a' as u8 && c <= 'z' as u8 {
+ c - 32
+ } else {
+ c
+ }
+ })
+}
+fn is_nvcuda_dll_utf16(lib: *const u16) -> bool {
+ is_nvcuda_dll(lib, 0u16, NVCUDA_UTF16, |c| {
+ if c >= 'a' as u16 && c <= 'z' as u16 {
+ c - 32
+ } else {
+ c
+ }
+ })
+}
+
+fn is_nvcuda_dll<T: Copy + PartialEq>(
+ lib: *const T,
+ zero: T,
+ dll_name: &[T],
+ uppercase: impl Fn(T) -> T,
+) -> bool {
+ let mut len = 0;
+ loop {
+ if unsafe { *lib.offset(len) } == zero {
+ break;
+ }
+ len += 1;
+ }
+ if (len as usize) < dll_name.len() {
+ return false;
+ }
+ let slice =
+ unsafe { slice::from_raw_parts(lib.offset(len - dll_name.len() as isize), dll_name.len()) };
+ for i in 0..dll_name.len() {
+ if uppercase(slice[i]) != dll_name[i] {
+ return false;
+ }
+ }
+ true
}
#[allow(non_snake_case)]
#[no_mangle]
unsafe extern "system" fn DllMain(_: *const u8, dwReason: u32, _: *const u8) -> i32 {
if dwReason == DLL_PROCESS_ATTACH {
- DetourRestoreAfterWith();
+ if DetourRestoreAfterWith() == FALSE {
+ return FALSE;
+ }
match get_zluda_dll_path() {
- Some((path, len)) => set_zluda_dll_path(path, len),
+ Some(path) => {
+ ZLUDA_PATH_UTF16 = Some(path);
+ ZLUDA_PATH_UTF8 = String::from_utf16_lossy(path).into_bytes();
+ }
None => return FALSE,
}
- DetourTransactionBegin();
- DetourUpdateThread(GetCurrentThread());
- DetourAttach(
- mem::transmute(&mut LOAD_LIBRARY_EX),
+ if DetourTransactionBegin() != NO_ERROR as i32 {
+ return FALSE;
+ }
+ if DetourUpdateThread(GetCurrentThread()) != NO_ERROR as i32 {
+ return FALSE;
+ }
+ if DetourAttach(
+ mem::transmute(&mut LOAD_LIBRARY_A),
+ ZludaLoadLibraryA as *mut _,
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourAttach(
+ mem::transmute(&mut LOAD_LIBRARY_W),
+ ZludaLoadLibraryW as *mut _,
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourAttach(
+ mem::transmute(&mut LOAD_LIBRARY_EX_A),
+ ZludaLoadLibraryExA as *mut _,
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourAttach(
+ mem::transmute(&mut LOAD_LIBRARY_EX_W),
ZludaLoadLibraryExW as *mut _,
- );
- DetourTransactionCommit();
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourTransactionCommit() != NO_ERROR as i32 {
+ return FALSE;
+ }
} else if dwReason == DLL_PROCESS_DETACH {
- DetourTransactionBegin();
- DetourUpdateThread(GetCurrentThread());
- DetourDetach(
- mem::transmute(&mut LOAD_LIBRARY_EX),
+ if DetourTransactionBegin() != NO_ERROR as i32 {
+ return FALSE;
+ }
+ if DetourUpdateThread(GetCurrentThread()) != NO_ERROR as i32 {
+ return FALSE;
+ }
+ if DetourDetach(
+ mem::transmute(&mut LOAD_LIBRARY_A),
+ ZludaLoadLibraryA as *mut _,
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourDetach(
+ mem::transmute(&mut LOAD_LIBRARY_W),
+ ZludaLoadLibraryW as *mut _,
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourDetach(
+ mem::transmute(&mut LOAD_LIBRARY_EX_A),
+ ZludaLoadLibraryExA as *mut _,
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourDetach(
+ mem::transmute(&mut LOAD_LIBRARY_EX_W),
ZludaLoadLibraryExW as *mut _,
- );
- DetourTransactionCommit();
+ ) != NO_ERROR as i32
+ {
+ return FALSE;
+ }
+ if DetourTransactionCommit() != NO_ERROR as i32 {
+ return FALSE;
+ }
}
TRUE
}
-fn get_zluda_dll_path() -> Option<(*const u16, usize)> {
+fn get_zluda_dll_path() -> Option<&'static [u16]> {
let mut module = ptr::null_mut();
loop {
module = unsafe { detours_sys::DetourEnumerateModules(module) };
@@ -79,26 +240,15 @@ fn get_zluda_dll_path() -> Option<(*const u16, usize)> {
break;
}
let mut size = 0;
- let payload = unsafe {
- detours_sys::DetourFindPayload(module, &PAYLOAD_GUID, &mut size)
- };
+ let payload = unsafe { detours_sys::DetourFindPayload(module, &PAYLOAD_GUID, &mut size) };
if payload != ptr::null_mut() {
- return Some((payload as *const _, (size as usize) / mem::size_of::<u16>()));
+ return unsafe {
+ Some(slice::from_raw_parts(
+ payload as *const _,
+ (size as usize) / mem::size_of::<u16>(),
+ ))
+ };
}
}
None
}
-
-unsafe fn set_zluda_dll_path(path: *const u16, len: usize) {
- let len = len as usize;
- let mut result = Vec::<u16>::with_capacity(len + ZLUDA_DLL.len() + 2);
- for i in 0..len {
- result.push(*path.add(i));
- }
- result.push(0x5c); // \
- for c in ZLUDA_DLL.iter().copied() {
- result.push(c);
- }
- result.push(0);
- ZLUDA_PATH = Some(result);
-}