aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_inject/src/bin.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda_inject/src/bin.rs')
-rw-r--r--zluda_inject/src/bin.rs271
1 files changed, 176 insertions, 95 deletions
diff --git a/zluda_inject/src/bin.rs b/zluda_inject/src/bin.rs
index b49496e..408f8ab 100644
--- a/zluda_inject/src/bin.rs
+++ b/zluda_inject/src/bin.rs
@@ -1,11 +1,14 @@
+use std::env;
+use std::os::windows;
use std::os::windows::ffi::OsStrExt;
-use std::path::Path;
-use std::ptr;
-use std::{env, ops::Deref};
use std::{error::Error, process};
+use std::{fs, io, ptr};
use std::{mem, path::PathBuf};
+use argh::FromArgs;
use mem::size_of_val;
+use tempfile::TempDir;
+use winapi::um::processenv::SearchPathW;
use winapi::um::{
jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
processthreadsapi::{GetExitCodeProcess, ResumeThread},
@@ -20,28 +23,46 @@ use winapi::um::{
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
-static ZLUDA_DLL: &'static str = "nvcuda.dll";
-static ZLUDA_ML_DLL: &'static str = "nvml.dll";
+static NVCUDA_DLL: &'static str = "nvcuda.dll";
+static NVML_DLL: &'static str = "nvml.dll";
include!("../../zluda_redirect/src/payload_guid.rs");
+#[derive(FromArgs)]
+/// Launch application with custom CUDA libraries
+struct ProgramArguments {
+ /// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory
+ #[argh(option)]
+ nvcuda: Option<PathBuf>,
+
+ /// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory
+ #[argh(option)]
+ nvml: Option<PathBuf>,
+
+ /// executable to be injected with custom CUDA libraries
+ #[argh(positional)]
+ exe: String,
+
+ /// arguments to the executable
+ #[argh(positional)]
+ args: Vec<String>,
+}
+
pub fn main_impl() -> Result<(), Box<dyn Error>> {
- let args = env::args().collect::<Vec<_>>();
- if args.len() <= 1 {
- print_help_and_exit();
- }
- let injector_path = env::current_exe()?;
- let injector_dir = injector_path.parent().unwrap();
- let redirect_path = create_redirect_path(injector_dir);
- let (mut inject_nvcuda_path, mut inject_nvml_path, cmd) =
- create_inject_path(&args[1..], injector_dir)?;
- let mut cmd_line = construct_command_line(cmd);
+ let raw_args = argh::from_env::<ProgramArguments>();
+ let normalized_args = NormalizedArguments::new(raw_args)?;
+ let mut environment = Environment::setup(normalized_args)?;
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
+ let mut dlls_to_inject = [
+ environment.nvml_path_zero_terminated.as_ptr() as *const i8,
+ environment.nvcuda_path_zero_terminated.as_ptr() as _,
+ environment.redirect_path_zero_terminated.as_ptr() as _,
+ ];
os_call!(
- detours_sys::DetourCreateProcessWithDllExW(
+ detours_sys::DetourCreateProcessWithDllsW(
ptr::null(),
- cmd_line.as_mut_ptr(),
+ environment.winapi_command_line_zero_terminated.as_mut_ptr(),
ptr::null_mut(),
ptr::null_mut(),
0,
@@ -50,7 +71,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
ptr::null(),
&mut startup_info as *mut _,
&mut proc_info as *mut _,
- redirect_path.as_ptr() as *const i8,
+ dlls_to_inject.len() as u32,
+ dlls_to_inject.as_mut_ptr(),
Option::None
),
|x| x != 0
@@ -60,8 +82,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess,
&PAYLOAD_NVCUDA_GUID,
- inject_nvcuda_path.as_mut_ptr() as *mut _,
- (inject_nvcuda_path.len() * mem::size_of::<u16>()) as u32
+ environment.nvcuda_path_zero_terminated.as_ptr() as *mut _,
+ environment.nvcuda_path_zero_terminated.len() as u32
),
|x| x != 0
);
@@ -69,8 +91,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess,
&PAYLOAD_NVML_GUID,
- inject_nvml_path.as_mut_ptr() as *mut _,
- (inject_nvml_path.len() * mem::size_of::<u16>()) as u32
+ environment.nvml_path_zero_terminated.as_ptr() as *mut _,
+ environment.nvml_path_zero_terminated.len() as u32
),
|x| x != 0
);
@@ -85,6 +107,135 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
process::exit(child_exit_code as i32)
}
+struct NormalizedArguments {
+ nvml_path: PathBuf,
+ nvcuda_path: PathBuf,
+ redirect_path: PathBuf,
+ winapi_command_line_zero_terminated: Vec<u16>,
+}
+
+impl NormalizedArguments {
+ fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
+ let current_exe = env::current_exe()?;
+ let nvml_path = Self::get_absolute_path(&current_exe, prog_args.nvml, NVML_DLL)?;
+ let nvcuda_path = Self::get_absolute_path(&current_exe, prog_args.nvcuda, NVCUDA_DLL)?;
+ let winapi_command_line_zero_terminated =
+ construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args));
+ let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
+ redirect_path.push(REDIRECT_DLL);
+ Ok(Self {
+ nvml_path,
+ nvcuda_path,
+ redirect_path,
+ winapi_command_line_zero_terminated,
+ })
+ }
+
+ const WIN_MAX_PATH: usize = 260;
+
+ fn get_absolute_path(
+ current_exe: &PathBuf,
+ dll: Option<PathBuf>,
+ default: &str,
+ ) -> Result<PathBuf, Box<dyn Error>> {
+ Ok(if let Some(dll) = dll {
+ if dll.is_absolute() {
+ dll
+ } else {
+ let mut full_dll_path = vec![0; Self::WIN_MAX_PATH];
+ let mut dll_utf16 = dll.as_os_str().encode_wide().collect::<Vec<_>>();
+ dll_utf16.push(0);
+ loop {
+ let copied_len = os_call!(
+ SearchPathW(
+ ptr::null_mut(),
+ dll_utf16.as_ptr(),
+ ptr::null(),
+ full_dll_path.len() as u32,
+ full_dll_path.as_mut_ptr(),
+ ptr::null_mut()
+ ),
+ |x| x != 0
+ ) as usize;
+ if copied_len > full_dll_path.len() {
+ full_dll_path.resize(copied_len + 1, 0);
+ } else {
+ full_dll_path.truncate(copied_len);
+ break;
+ }
+ }
+ PathBuf::from(String::from_utf16_lossy(&full_dll_path))
+ }
+ } else {
+ let mut dll_path = current_exe.parent().unwrap().to_path_buf();
+ dll_path.push(default);
+ dll_path
+ })
+ }
+}
+
+struct Environment {
+ nvml_path_zero_terminated: String,
+ nvcuda_path_zero_terminated: String,
+ redirect_path_zero_terminated: String,
+ winapi_command_line_zero_terminated: Vec<u16>,
+ _temp_dir: TempDir,
+}
+
+// This structs represents "enviroment". By environment we mean all paths
+// (nvcuda.dll, nvml.dll, etc.) and all related resources like the temporary
+// directory which contains nvcuda.dll
+impl Environment {
+ fn setup(args: NormalizedArguments) -> io::Result<Self> {
+ let _temp_dir = TempDir::new()?;
+ let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
+ args.nvml_path,
+ &_temp_dir,
+ NVML_DLL,
+ )?);
+ let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
+ args.nvcuda_path,
+ &_temp_dir,
+ NVCUDA_DLL,
+ )?);
+ let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
+ Ok(Self {
+ nvml_path_zero_terminated,
+ nvcuda_path_zero_terminated,
+ redirect_path_zero_terminated,
+ winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated,
+ _temp_dir,
+ })
+ }
+
+ fn copy_to_correct_name(
+ path_buf: PathBuf,
+ temp_dir: &TempDir,
+ correct_name: &str,
+ ) -> io::Result<PathBuf> {
+ let file_name = path_buf.file_name().unwrap();
+ if file_name == correct_name {
+ Ok(path_buf)
+ } else {
+ let mut temp_file_path = temp_dir.path().to_path_buf();
+ temp_file_path.push(correct_name);
+ match windows::fs::symlink_file(&path_buf, &temp_file_path) {
+ Ok(()) => {}
+ Err(_) => {
+ fs::copy(&path_buf, &temp_file_path)?;
+ }
+ }
+ Ok(temp_file_path)
+ }
+ }
+
+ fn zero_terminate(p: PathBuf) -> String {
+ let mut s = p.to_string_lossy().to_string();
+ s.push('\0');
+ s
+ }
+}
+
fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
!= ptr::null_mut());
@@ -103,29 +254,11 @@ fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
Ok(())
}
-fn print_help_and_exit() -> ! {
- let current_exe = env::current_exe().unwrap();
- let exe_name = current_exe.file_name().unwrap().to_string_lossy();
- println!(
- "USAGE:
- {0} -- <EXE> [ARGS]...
- {0} <DLL> -- <EXE> [ARGS]...
-ARGS:
- <DLL> DLL to be injected instead of system nvcuda.dll, if not provided
- will use nvcuda.dll from the directory where {0} is located
- <EXE> Path to the executable to be injected with <DLL>
- <ARGS>... Arguments that will be passed to <EXE>
-",
- exe_name
- );
- process::exit(1)
-}
-
// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
-fn construct_command_line(args: &[String]) -> Vec<u16> {
+fn construct_command_line(args: impl Iterator<Item = String>) -> Vec<u16> {
let mut cmd_line = Vec::new();
- let args_len = args.len();
- for (idx, arg) in args.iter().enumerate() {
+ let args_len = args.size_hint().0;
+ for (idx, arg) in args.enumerate() {
if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) {
cmd_line.extend(arg.encode_utf16());
} else {
@@ -176,55 +309,3 @@ fn construct_command_line(args: &[String]) -> Vec<u16> {
cmd_line.push(0);
cmd_line
}
-
-fn create_redirect_path(injector_dir: &Path) -> Vec<u8> {
- let mut injector_dir = injector_dir.to_path_buf();
- injector_dir.push(REDIRECT_DLL);
- let mut result = injector_dir.to_string_lossy().into_owned().into_bytes();
- result.push(0);
- result
-}
-
-fn create_inject_path<'a>(
- args: &'a [String],
- injector_dir: &Path,
-) -> std::io::Result<(Vec<u16>, Vec<u16>, &'a [String])> {
- let injector_dir = injector_dir.to_path_buf();
- let (nvcuda_path, unparsed_args) = if args.get(0).map(Deref::deref) == Some("--") {
- (
- encode_file_in_directory_raw(injector_dir.clone(), ZLUDA_DLL),
- &args[1..],
- )
- } else if args.get(1).map(Deref::deref) == Some("--") {
- let dll_path = make_absolute_and_encode(&args[0])?;
- (dll_path, &args[2..])
- } else {
- print_help_and_exit()
- };
- let nvml_path = encode_file_in_directory_raw(injector_dir, ZLUDA_ML_DLL);
- Ok((nvcuda_path, nvml_path, unparsed_args))
-}
-
-fn encode_file_in_directory_raw(mut dir: PathBuf, file: &'static str) -> Vec<u16> {
- dir.push(file);
- let mut result = dir
- .to_string_lossy()
- .as_ref()
- .encode_utf16()
- .collect::<Vec<_>>();
- result.push(0);
- result
-}
-
-fn make_absolute_and_encode(maybe_path: &str) -> std::io::Result<Vec<u16>> {
- let path = Path::new(maybe_path);
- let mut encoded_path = if path.is_relative() {
- let mut current_dir = env::current_dir()?;
- current_dir.push(path);
- current_dir.as_os_str().encode_wide().collect::<Vec<_>>()
- } else {
- maybe_path.encode_utf16().collect::<Vec<_>>()
- };
- encoded_path.push(0);
- Ok(encoded_path)
-}