aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_inject/src
diff options
context:
space:
mode:
Diffstat (limited to 'zluda_inject/src')
-rw-r--r--zluda_inject/src/bin.rs199
-rw-r--r--zluda_inject/src/main.rs20
-rw-r--r--zluda_inject/src/win.rs21
3 files changed, 164 insertions, 76 deletions
diff --git a/zluda_inject/src/bin.rs b/zluda_inject/src/bin.rs
index e021a41..af44d74 100644
--- a/zluda_inject/src/bin.rs
+++ b/zluda_inject/src/bin.rs
@@ -1,58 +1,39 @@
-extern crate clap;
-#[macro_use]
-extern crate guid;
-extern crate detours_sys;
-extern crate winapi;
-
-use std::error::Error;
-use std::ffi::OsStr;
+use std::env;
+use std::env::Args;
use std::mem;
-use std::os::windows::ffi::OsStrExt;
+use std::path::Path;
use std::ptr;
+use std::{error::Error, process};
+
+use mem::size_of_val;
+use winapi::um::{
+ jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
+ processthreadsapi::{GetExitCodeProcess, ResumeThread},
+ synchapi::WaitForSingleObject,
+ winbase::CreateJobObjectA,
+ winnt::{
+ JobObjectExtendedLimitInformation, HANDLE, JOBOBJECT_EXTENDED_LIMIT_INFORMATION,
+ JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
+ },
+};
-use winapi::um::processthreadsapi::{GetExitCodeProcess, ResumeThread};
-use winapi::um::synchapi::WaitForSingleObject;
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
-use clap::{App, AppSettings, Arg};
+static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
+static ZLUDA_DLL: &'static str = "nvcuda.dll";
-#[macro_use]
-mod win;
+include!("../../zluda_redirect/src/payload_guid.rs");
-fn main() -> Result<(), Box<dyn Error>> {
- let matches = App::new("ZLUDA injector")
- .setting(AppSettings::TrailingVarArg)
- .arg(
- Arg::with_name("EXE")
- .help("Path to the executable to be injected with ZLUDA")
- .required(true),
- )
- .arg(
- Arg::with_name("ARGS")
- .multiple(true)
- .help("Arguments that will be passed to <EXE>"),
- )
- .get_matches();
- let exe = matches.value_of_os("EXE").unwrap();
- let args: Vec<&OsStr> = matches
- .values_of_os("ARGS")
- .map(|x| x.collect())
- .unwrap_or_else(|| Vec::new());
- let mut cmd_line = Vec::<u16>::with_capacity(exe.len() + 2);
- cmd_line.push('\"' as u16);
- copy_to(exe, &mut cmd_line);
- cmd_line.push('\"' as u16);
- cmd_line.push(' ' as u16);
- args.split_last().map(|(last_arg, args)| {
- for arg in args {
- cmd_line.reserve(arg.len());
- copy_to(arg, &mut cmd_line);
- cmd_line.push(' ' as u16);
- }
- copy_to(last_arg, &mut cmd_line);
- });
-
- cmd_line.push(0);
+pub fn main_impl() -> Result<(), Box<dyn Error>> {
+ let args = env::args();
+ if args.len() == 0 {
+ print_help();
+ process::exit(1);
+ }
+ let mut cmd_line = construct_command_line(args);
+ let injector_path = env::current_exe()?;
+ let injector_dir = injector_path.parent().unwrap();
+ let redirect_path = create_redirect_path(injector_dir);
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
os_call!(
@@ -67,22 +48,19 @@ fn main() -> Result<(), Box<dyn Error>> {
ptr::null(),
&mut startup_info as *mut _,
&mut proc_info as *mut _,
- "zluda_redirect.dll\0".as_ptr() as *const i8,
+ redirect_path.as_ptr() as *const i8,
Option::None
),
|x| x != 0
);
- let mut exe_path = std::env::current_dir()?
- .as_os_str()
- .encode_wide()
- .collect::<Vec<_>>();
- let guid = guid! {"C225FC0C-00D7-40B8-935A-7E342A9344C1"};
+ kill_child_on_process_exit(proc_info.hProcess)?;
+ let mut zluda_path = create_zluda_path(injector_dir);
os_call!(
detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess,
- mem::transmute(&guid),
- exe_path.as_mut_ptr() as *mut _,
- (exe_path.len() * mem::size_of::<u16>()) as u32
+ &PAYLOAD_GUID,
+ zluda_path.as_mut_ptr() as *mut _,
+ (zluda_path.len() * mem::size_of::<u16>()) as u32
),
|x| x != 0
);
@@ -94,13 +72,110 @@ fn main() -> Result<(), Box<dyn Error>> {
GetExitCodeProcess(proc_info.hProcess, &mut child_exit_code as *mut _),
|x| x != 0
);
- std::process::exit(child_exit_code as i32)
+ process::exit(child_exit_code as i32)
}
-fn copy_to(from: &OsStr, to: &mut Vec<u16>) {
- for x in from.encode_wide() {
- to.push(x);
+fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
+ let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
+ != ptr::null_mut());
+ let mut info = unsafe { mem::zeroed::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() };
+ info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
+ os_call!(
+ SetInformationJobObject(
+ job_handle,
+ JobObjectExtendedLimitInformation,
+ &mut info as *mut _ as *mut _,
+ size_of_val(&info) as u32
+ ),
+ |x| x != 0
+ );
+ os_call!(AssignProcessToJobObject(job_handle, child), |x| x != 0);
+ Ok(())
+}
+
+fn print_help() {
+ println!(
+ "USAGE:
+ zluda <EXE> [ARGS]...
+ARGS:
+ <EXE> Path to the executable to be injected with ZLUDA
+ <ARGS>... Arguments that will be passed to <EXE>
+"
+ );
+}
+
+// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
+fn construct_command_line(args: Args) -> Vec<u16> {
+ let mut cmd_line = Vec::new();
+ let args_len = args.len();
+ for (idx, arg) in args.enumerate().skip(1) {
+ if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) {
+ cmd_line.extend(arg.encode_utf16());
+ } else {
+ cmd_line.push('"' as u16); // "
+ let mut char_iter = arg.chars().peekable();
+ loop {
+ let mut current = char_iter.next();
+ let mut backslashes = 0;
+ match current {
+ Some('\\') => {
+ backslashes = 1;
+ while let Some('\\') = char_iter.peek() {
+ backslashes += 1;
+ char_iter.next();
+ }
+ current = char_iter.next();
+ }
+ _ => {}
+ }
+ match current {
+ None => {
+ for _ in 0..(backslashes * 2) {
+ cmd_line.push('\\' as u16);
+ }
+ break;
+ }
+ Some('"') => {
+ for _ in 0..(backslashes * 2 + 1) {
+ cmd_line.push('\\' as u16);
+ }
+ cmd_line.push('"' as u16);
+ }
+ Some(c) => {
+ for _ in 0..backslashes {
+ cmd_line.push('\\' as u16);
+ }
+ let mut temp = [0u16; 2];
+ cmd_line.extend(&*c.encode_utf16(&mut temp));
+ }
+ }
+ }
+ cmd_line.push('"' as u16);
+ }
+ if idx < args_len - 1 {
+ cmd_line.push(' ' as u16);
+ }
}
+ cmd_line.push(0);
+ cmd_line
+}
+
+fn create_redirect_path(injector_dir: &Path) -> Vec<u8> {
+ let mut injector_dir = injector_dir.to_path_buf();
+ injector_dir.push(REDIRECT_DLL);
+ let mut result = injector_dir.to_string_lossy().into_owned().into_bytes();
+ result.push(0);
+ result
}
-//
+fn create_zluda_path(injector_dir: &Path) -> Vec<u16> {
+ let mut injector_dir = injector_dir.to_path_buf();
+ injector_dir.push(ZLUDA_DLL);
+ let mut result = injector_dir
+ .to_string_lossy()
+ .as_ref()
+ .encode_utf16()
+ .collect::<Vec<_>>();
+ result.push(0);
+ result
+}
diff --git a/zluda_inject/src/main.rs b/zluda_inject/src/main.rs
index 6e292f2..f8c1921 100644
--- a/zluda_inject/src/main.rs
+++ b/zluda_inject/src/main.rs
@@ -1,5 +1,15 @@
-#[cfg(target_os = "windows")]
-mod bin;
-
-#[cfg(not(target_os = "windows"))]
-fn main() {} \ No newline at end of file
+#[macro_use]
+#[cfg(target_os = "windows")]
+mod win;
+#[cfg(target_os = "windows")]
+mod bin;
+
+use std::error::Error;
+
+#[cfg(target_os = "windows")]
+fn main() -> Result<(), Box<dyn Error>> {
+ bin::main_impl()
+}
+
+#[cfg(not(target_os = "windows"))]
+fn main() {}
diff --git a/zluda_inject/src/win.rs b/zluda_inject/src/win.rs
index ec57ffb..4d7fcdd 100644
--- a/zluda_inject/src/win.rs
+++ b/zluda_inject/src/win.rs
@@ -48,15 +48,18 @@ macro_rules! last_ident {
macro_rules! os_call {
($($path:ident)::+ ($($args:expr),*), $success:expr) => {
- let result = unsafe{ $($path)::+ ($($args),+) };
- if !($success)(result) {
- let name = last_ident!($($path),+);
- let err_code = $crate::win::errno();
- Err($crate::win::OsError{
- function: name,
- error_code: err_code as u32,
- message: $crate::win::error_string(err_code)
- })?;
+ {
+ let result = unsafe{ $($path)::+ ($($args),*) };
+ if !($success)(result) {
+ let name = last_ident!($($path),+);
+ let err_code = $crate::win::errno();
+ Err($crate::win::OsError{
+ function: name,
+ error_code: err_code as u32,
+ message: $crate::win::error_string(err_code)
+ })?;
+ }
+ result
}
};
}