From 2753d956df0ee3d68c3961f7b64e65df9f06bb0b Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 4 Feb 2022 00:50:25 +0100 Subject: Overhaul DLL injection --- zluda_dump/src/os_win.rs | 1 - zluda_inject/Cargo.toml | 2 + zluda_inject/build.rs | 6 +- zluda_inject/src/bin.rs | 271 +++++++++------ zluda_inject/tests/helpers/do_cuinit_early.rs | 10 + zluda_inject/tests/helpers/do_cuinit_late.rs | 23 ++ zluda_inject/tests/helpers/do_cuinit_late_clr.cs | 34 ++ zluda_inject/tests/helpers/do_cuinit_late_clr.exe | Bin 0 -> 4608 bytes zluda_inject/tests/helpers/do_cuinit_main.rs | 23 -- zluda_inject/tests/helpers/do_cuinit_main_clr.cs | 34 -- zluda_inject/tests/helpers/do_cuinit_main_clr.exe | Bin 4608 -> 0 bytes zluda_inject/tests/helpers/subprocess.rs | 10 + zluda_inject/tests/inject.rs | 28 +- zluda_ml/src/impl.rs | 2 - zluda_redirect/Cargo.toml | 3 - zluda_redirect/src/lib.rs | 390 ++++------------------ 16 files changed, 351 insertions(+), 486 deletions(-) create mode 100644 zluda_inject/tests/helpers/do_cuinit_early.rs create mode 100644 zluda_inject/tests/helpers/do_cuinit_late.rs create mode 100644 zluda_inject/tests/helpers/do_cuinit_late_clr.cs create mode 100644 zluda_inject/tests/helpers/do_cuinit_late_clr.exe delete mode 100644 zluda_inject/tests/helpers/do_cuinit_main.rs delete mode 100644 zluda_inject/tests/helpers/do_cuinit_main_clr.cs delete mode 100644 zluda_inject/tests/helpers/do_cuinit_main_clr.exe create mode 100644 zluda_inject/tests/helpers/subprocess.rs diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs index ef3da44..b58a0f4 100644 --- a/zluda_dump/src/os_win.rs +++ b/zluda_dump/src/os_win.rs @@ -18,7 +18,6 @@ const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedi lazy_static! 
{ static ref PLATFORM_LIBRARY: PlatformLibrary = unsafe { PlatformLibrary::new() }; } -include!("../../zluda_redirect/src/payload_guid.rs"); #[allow(non_snake_case)] struct PlatformLibrary { diff --git a/zluda_inject/Cargo.toml b/zluda_inject/Cargo.toml index 73489bb..65113a4 100644 --- a/zluda_inject/Cargo.toml +++ b/zluda_inject/Cargo.toml @@ -10,6 +10,8 @@ path = "src/main.rs" [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchapi", "winbase", "std"] } +tempfile = "3" +argh = "0.1" detours-sys = { path = "../detours-sys" } [dev-dependencies] diff --git a/zluda_inject/build.rs b/zluda_inject/build.rs index 1e425a7..ccce573 100644 --- a/zluda_inject/build.rs +++ b/zluda_inject/build.rs @@ -43,6 +43,8 @@ fn main() -> Result<(), VarError> { .arg("-ldylib=nvcuda") .arg("-C") .arg(format!("opt-level={}", opt_level)) + .arg("-L") + .arg(format!("{}", out_dir)) .arg("--out-dir") .arg(format!("{}", out_dir)) .arg("--target") @@ -52,11 +54,11 @@ fn main() -> Result<(), VarError> { } std::fs::copy( format!( - "{}{}do_cuinit_main_clr.exe", + "{}{}do_cuinit_late_clr.exe", helpers_dir_as_string, path::MAIN_SEPARATOR ), - format!("{}{}do_cuinit_main_clr.exe", out_dir, path::MAIN_SEPARATOR), + format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR), ) .unwrap(); println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir); diff --git a/zluda_inject/src/bin.rs b/zluda_inject/src/bin.rs index b49496e..408f8ab 100644 --- a/zluda_inject/src/bin.rs +++ b/zluda_inject/src/bin.rs @@ -1,11 +1,14 @@ +use std::env; +use std::os::windows; use std::os::windows::ffi::OsStrExt; -use std::path::Path; -use std::ptr; -use std::{env, ops::Deref}; use std::{error::Error, process}; +use std::{fs, io, ptr}; use std::{mem, path::PathBuf}; +use argh::FromArgs; use mem::size_of_val; +use tempfile::TempDir; +use winapi::um::processenv::SearchPathW; use winapi::um::{ jobapi2::{AssignProcessToJobObject, SetInformationJobObject}, 
processthreadsapi::{GetExitCodeProcess, ResumeThread}, @@ -20,28 +23,46 @@ use winapi::um::{ use winapi::um::winbase::{INFINITE, WAIT_FAILED}; static REDIRECT_DLL: &'static str = "zluda_redirect.dll"; -static ZLUDA_DLL: &'static str = "nvcuda.dll"; -static ZLUDA_ML_DLL: &'static str = "nvml.dll"; +static NVCUDA_DLL: &'static str = "nvcuda.dll"; +static NVML_DLL: &'static str = "nvml.dll"; include!("../../zluda_redirect/src/payload_guid.rs"); +#[derive(FromArgs)] +/// Launch application with custom CUDA libraries +struct ProgramArguments { + /// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory + #[argh(option)] + nvcuda: Option, + + /// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory + #[argh(option)] + nvml: Option, + + /// executable to be injected with custom CUDA libraries + #[argh(positional)] + exe: String, + + /// arguments to the executable + #[argh(positional)] + args: Vec, +} + pub fn main_impl() -> Result<(), Box> { - let args = env::args().collect::>(); - if args.len() <= 1 { - print_help_and_exit(); - } - let injector_path = env::current_exe()?; - let injector_dir = injector_path.parent().unwrap(); - let redirect_path = create_redirect_path(injector_dir); - let (mut inject_nvcuda_path, mut inject_nvml_path, cmd) = - create_inject_path(&args[1..], injector_dir)?; - let mut cmd_line = construct_command_line(cmd); + let raw_args = argh::from_env::(); + let normalized_args = NormalizedArguments::new(raw_args)?; + let mut environment = Environment::setup(normalized_args)?; let mut startup_info = unsafe { mem::zeroed::() }; let mut proc_info = unsafe { mem::zeroed::() }; + let mut dlls_to_inject = [ + environment.nvml_path_zero_terminated.as_ptr() as *const i8, + environment.nvcuda_path_zero_terminated.as_ptr() as _, + environment.redirect_path_zero_terminated.as_ptr() as _, + ]; os_call!( - detours_sys::DetourCreateProcessWithDllExW( + 
detours_sys::DetourCreateProcessWithDllsW( ptr::null(), - cmd_line.as_mut_ptr(), + environment.winapi_command_line_zero_terminated.as_mut_ptr(), ptr::null_mut(), ptr::null_mut(), 0, @@ -50,7 +71,8 @@ pub fn main_impl() -> Result<(), Box> { ptr::null(), &mut startup_info as *mut _, &mut proc_info as *mut _, - redirect_path.as_ptr() as *const i8, + dlls_to_inject.len() as u32, + dlls_to_inject.as_mut_ptr(), Option::None ), |x| x != 0 @@ -60,8 +82,8 @@ pub fn main_impl() -> Result<(), Box> { detours_sys::DetourCopyPayloadToProcess( proc_info.hProcess, &PAYLOAD_NVCUDA_GUID, - inject_nvcuda_path.as_mut_ptr() as *mut _, - (inject_nvcuda_path.len() * mem::size_of::()) as u32 + environment.nvcuda_path_zero_terminated.as_ptr() as *mut _, + environment.nvcuda_path_zero_terminated.len() as u32 ), |x| x != 0 ); @@ -69,8 +91,8 @@ pub fn main_impl() -> Result<(), Box> { detours_sys::DetourCopyPayloadToProcess( proc_info.hProcess, &PAYLOAD_NVML_GUID, - inject_nvml_path.as_mut_ptr() as *mut _, - (inject_nvml_path.len() * mem::size_of::()) as u32 + environment.nvml_path_zero_terminated.as_ptr() as *mut _, + environment.nvml_path_zero_terminated.len() as u32 ), |x| x != 0 ); @@ -85,6 +107,135 @@ pub fn main_impl() -> Result<(), Box> { process::exit(child_exit_code as i32) } +struct NormalizedArguments { + nvml_path: PathBuf, + nvcuda_path: PathBuf, + redirect_path: PathBuf, + winapi_command_line_zero_terminated: Vec, +} + +impl NormalizedArguments { + fn new(prog_args: ProgramArguments) -> Result> { + let current_exe = env::current_exe()?; + let nvml_path = Self::get_absolute_path(&current_exe, prog_args.nvml, NVML_DLL)?; + let nvcuda_path = Self::get_absolute_path(&current_exe, prog_args.nvcuda, NVCUDA_DLL)?; + let winapi_command_line_zero_terminated = + construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args)); + let mut redirect_path = current_exe.parent().unwrap().to_path_buf(); + redirect_path.push(REDIRECT_DLL); + Ok(Self { + nvml_path, + nvcuda_path, + redirect_path, +
winapi_command_line_zero_terminated, + }) + } + + const WIN_MAX_PATH: usize = 260; + + fn get_absolute_path( + current_exe: &PathBuf, + dll: Option, + default: &str, + ) -> Result> { + Ok(if let Some(dll) = dll { + if dll.is_absolute() { + dll + } else { + let mut full_dll_path = vec![0; Self::WIN_MAX_PATH]; + let mut dll_utf16 = dll.as_os_str().encode_wide().collect::>(); + dll_utf16.push(0); + loop { + let copied_len = os_call!( + SearchPathW( + ptr::null_mut(), + dll_utf16.as_ptr(), + ptr::null(), + full_dll_path.len() as u32, + full_dll_path.as_mut_ptr(), + ptr::null_mut() + ), + |x| x != 0 + ) as usize; + if copied_len > full_dll_path.len() { + full_dll_path.resize(copied_len + 1, 0); + } else { + full_dll_path.truncate(copied_len); + break; + } + } + PathBuf::from(String::from_utf16_lossy(&full_dll_path)) + } + } else { + let mut dll_path = current_exe.parent().unwrap().to_path_buf(); + dll_path.push(default); + dll_path + }) + } +} + +struct Environment { + nvml_path_zero_terminated: String, + nvcuda_path_zero_terminated: String, + redirect_path_zero_terminated: String, + winapi_command_line_zero_terminated: Vec, + _temp_dir: TempDir, +} + +// This structs represents "enviroment". By environment we mean all paths +// (nvcuda.dll, nvml.dll, etc.) 
and all related resources like the temporary +// directory which contains nvcuda.dll +impl Environment { + fn setup(args: NormalizedArguments) -> io::Result { + let _temp_dir = TempDir::new()?; + let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( + args.nvml_path, + &_temp_dir, + NVML_DLL, + )?); + let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( + args.nvcuda_path, + &_temp_dir, + NVCUDA_DLL, + )?); + let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path); + Ok(Self { + nvml_path_zero_terminated, + nvcuda_path_zero_terminated, + redirect_path_zero_terminated, + winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated, + _temp_dir, + }) + } + + fn copy_to_correct_name( + path_buf: PathBuf, + temp_dir: &TempDir, + correct_name: &str, + ) -> io::Result { + let file_name = path_buf.file_name().unwrap(); + if file_name == correct_name { + Ok(path_buf) + } else { + let mut temp_file_path = temp_dir.path().to_path_buf(); + temp_file_path.push(correct_name); + match windows::fs::symlink_file(&path_buf, &temp_file_path) { + Ok(()) => {} + Err(_) => { + fs::copy(&path_buf, &temp_file_path)?; + } + } + Ok(temp_file_path) + } + } + + fn zero_terminate(p: PathBuf) -> String { + let mut s = p.to_string_lossy().to_string(); + s.push('\0'); + s + } +} + fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box> { let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x != ptr::null_mut()); @@ -103,29 +254,11 @@ fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box> { Ok(()) } -fn print_help_and_exit() -> ! { - let current_exe = env::current_exe().unwrap(); - let exe_name = current_exe.file_name().unwrap().to_string_lossy(); - println!( - "USAGE: - {0} -- [ARGS]... - {0} -- [ARGS]... 
-ARGS: - DLL to be injected instead of system nvcuda.dll, if not provided - will use nvcuda.dll from the directory where {0} is located - Path to the executable to be injected with - ... Arguments that will be passed to -", - exe_name - ); - process::exit(1) -} - // Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way -fn construct_command_line(args: &[String]) -> Vec { +fn construct_command_line(args: impl Iterator) -> Vec { let mut cmd_line = Vec::new(); - let args_len = args.len(); - for (idx, arg) in args.iter().enumerate() { + let args_len = args.size_hint().0; + for (idx, arg) in args.enumerate() { if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) { cmd_line.extend(arg.encode_utf16()); } else { @@ -176,55 +309,3 @@ fn construct_command_line(args: &[String]) -> Vec { cmd_line.push(0); cmd_line } - -fn create_redirect_path(injector_dir: &Path) -> Vec { - let mut injector_dir = injector_dir.to_path_buf(); - injector_dir.push(REDIRECT_DLL); - let mut result = injector_dir.to_string_lossy().into_owned().into_bytes(); - result.push(0); - result -} - -fn create_inject_path<'a>( - args: &'a [String], - injector_dir: &Path, -) -> std::io::Result<(Vec, Vec, &'a [String])> { - let injector_dir = injector_dir.to_path_buf(); - let (nvcuda_path, unparsed_args) = if args.get(0).map(Deref::deref) == Some("--") { - ( - encode_file_in_directory_raw(injector_dir.clone(), ZLUDA_DLL), - &args[1..], - ) - } else if args.get(1).map(Deref::deref) == Some("--") { - let dll_path = make_absolute_and_encode(&args[0])?; - (dll_path, &args[2..]) - } else { - print_help_and_exit() - }; - let nvml_path = encode_file_in_directory_raw(injector_dir, ZLUDA_ML_DLL); - Ok((nvcuda_path, nvml_path, unparsed_args)) -} - -fn encode_file_in_directory_raw(mut dir: PathBuf, file: &'static str) -> Vec { - dir.push(file); - let mut result = dir - .to_string_lossy() - .as_ref() - .encode_utf16() - 
.collect::>(); - result.push(0); - result -} - -fn make_absolute_and_encode(maybe_path: &str) -> std::io::Result> { - let path = Path::new(maybe_path); - let mut encoded_path = if path.is_relative() { - let mut current_dir = env::current_dir()?; - current_dir.push(path); - current_dir.as_os_str().encode_wide().collect::>() - } else { - maybe_path.encode_utf16().collect::>() - }; - encoded_path.push(0); - Ok(encoded_path) -} diff --git a/zluda_inject/tests/helpers/do_cuinit_early.rs b/zluda_inject/tests/helpers/do_cuinit_early.rs new file mode 100644 index 0000000..9743f4a --- /dev/null +++ b/zluda_inject/tests/helpers/do_cuinit_early.rs @@ -0,0 +1,10 @@ +#![crate_type = "bin"] + +#[link(name = "do_cuinit")] +extern "system" { + fn do_cuinit(flags: u32) -> u32; +} + +fn main() { + unsafe { do_cuinit(0) }; +} diff --git a/zluda_inject/tests/helpers/do_cuinit_late.rs b/zluda_inject/tests/helpers/do_cuinit_late.rs new file mode 100644 index 0000000..ab3516d --- /dev/null +++ b/zluda_inject/tests/helpers/do_cuinit_late.rs @@ -0,0 +1,23 @@ +#![crate_type = "bin"] + +use std::ffi::c_void; +use std::mem; +use std::env; +use std::path::PathBuf; +use std::ffi::CString; + +extern "system" { + fn LoadLibraryA(lpFileName: *const i8) -> *mut c_void; + fn GetProcAddress(hModule: *mut c_void, lpProcName: *const u8) -> *mut c_void; +} + +fn main() { + let current_exe = env::current_exe().unwrap(); + let mut dll = PathBuf::from(current_exe.parent().unwrap()); + dll.push("do_cuinit.dll"); + let dll_cstring = CString::new(dll.to_str().unwrap()).unwrap(); + let nvcuda = unsafe { LoadLibraryA(dll_cstring.as_ptr()) }; + let cu_init = unsafe { GetProcAddress(nvcuda, b"do_cuinit\0".as_ptr()) }; + let cu_init = unsafe { mem::transmute::<_, unsafe extern "system" fn(u32) -> u32>(cu_init) }; + unsafe { cu_init(0) }; +} diff --git a/zluda_inject/tests/helpers/do_cuinit_late_clr.cs b/zluda_inject/tests/helpers/do_cuinit_late_clr.cs new file mode 100644 index 0000000..666c237 --- /dev/null +++ 
b/zluda_inject/tests/helpers/do_cuinit_late_clr.cs @@ -0,0 +1,34 @@ +using System; +using System.IO; +using System.Reflection; +using System.Runtime.InteropServices; + +namespace Zluda +{ + class Program + { + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + private delegate int CuInit(int flags); + + static int Main(string[] args) + { + DirectoryInfo exeDirectory = Directory.GetParent(Assembly.GetEntryAssembly().Location); + string dllPath = Path.Combine(exeDirectory.ToString(), "do_cuinit.dll"); + IntPtr nvcuda = NativeMethods.LoadLibrary(dllPath); + if (nvcuda == IntPtr.Zero) + return 1; + IntPtr doCuinitPtr = NativeMethods.GetProcAddress(nvcuda, "do_cuinit"); + CuInit cuinit = (CuInit)Marshal.GetDelegateForFunctionPointer(doCuinitPtr, typeof(CuInit)); + return cuinit(0); + } + } + + static class NativeMethods + { + [DllImport("kernel32.dll")] + public static extern IntPtr LoadLibrary(string dllToLoad); + + [DllImport("kernel32.dll")] + public static extern IntPtr GetProcAddress(IntPtr hModule, string procedureName); + } +} \ No newline at end of file diff --git a/zluda_inject/tests/helpers/do_cuinit_late_clr.exe b/zluda_inject/tests/helpers/do_cuinit_late_clr.exe new file mode 100644 index 0000000..b8e4975 Binary files /dev/null and b/zluda_inject/tests/helpers/do_cuinit_late_clr.exe differ diff --git a/zluda_inject/tests/helpers/do_cuinit_main.rs b/zluda_inject/tests/helpers/do_cuinit_main.rs deleted file mode 100644 index ab3516d..0000000 --- a/zluda_inject/tests/helpers/do_cuinit_main.rs +++ /dev/null @@ -1,23 +0,0 @@ -#![crate_type = "bin"] - -use std::ffi::c_void; -use std::mem; -use std::env; -use std::path::PathBuf; -use std::ffi::CString; - -extern "system" { - fn LoadLibraryA(lpFileName: *const i8) -> *mut c_void; - fn GetProcAddress(hModule: *mut c_void, lpProcName: *const u8) -> *mut c_void; -} - -fn main() { - let current_exe = env::current_exe().unwrap(); - let mut dll = PathBuf::from(current_exe.parent().unwrap()); - 
dll.push("do_cuinit.dll"); - let dll_cstring = CString::new(dll.to_str().unwrap()).unwrap(); - let nvcuda = unsafe { LoadLibraryA(dll_cstring.as_ptr()) }; - let cu_init = unsafe { GetProcAddress(nvcuda, b"do_cuinit\0".as_ptr()) }; - let cu_init = unsafe { mem::transmute::<_, unsafe extern "system" fn(u32) -> u32>(cu_init) }; - unsafe { cu_init(0) }; -} diff --git a/zluda_inject/tests/helpers/do_cuinit_main_clr.cs b/zluda_inject/tests/helpers/do_cuinit_main_clr.cs deleted file mode 100644 index 666c237..0000000 --- a/zluda_inject/tests/helpers/do_cuinit_main_clr.cs +++ /dev/null @@ -1,34 +0,0 @@ -using System; -using System.IO; -using System.Reflection; -using System.Runtime.InteropServices; - -namespace Zluda -{ - class Program - { - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - private delegate int CuInit(int flags); - - static int Main(string[] args) - { - DirectoryInfo exeDirectory = Directory.GetParent(Assembly.GetEntryAssembly().Location); - string dllPath = Path.Combine(exeDirectory.ToString(), "do_cuinit.dll"); - IntPtr nvcuda = NativeMethods.LoadLibrary(dllPath); - if (nvcuda == IntPtr.Zero) - return 1; - IntPtr doCuinitPtr = NativeMethods.GetProcAddress(nvcuda, "do_cuinit"); - CuInit cuinit = (CuInit)Marshal.GetDelegateForFunctionPointer(doCuinitPtr, typeof(CuInit)); - return cuinit(0); - } - } - - static class NativeMethods - { - [DllImport("kernel32.dll")] - public static extern IntPtr LoadLibrary(string dllToLoad); - - [DllImport("kernel32.dll")] - public static extern IntPtr GetProcAddress(IntPtr hModule, string procedureName); - } -} \ No newline at end of file diff --git a/zluda_inject/tests/helpers/do_cuinit_main_clr.exe b/zluda_inject/tests/helpers/do_cuinit_main_clr.exe deleted file mode 100644 index b8e4975..0000000 Binary files a/zluda_inject/tests/helpers/do_cuinit_main_clr.exe and /dev/null differ diff --git a/zluda_inject/tests/helpers/subprocess.rs b/zluda_inject/tests/helpers/subprocess.rs new file mode 100644 index 
0000000..3d1588c --- /dev/null +++ b/zluda_inject/tests/helpers/subprocess.rs @@ -0,0 +1,10 @@ +#![crate_type = "bin"] + +use std::io; +use std::process::Command; + +fn main() -> io::Result<()> { + let status = Command::new("direct_cuinit.exe").status()?; + assert!(status.success()); + Ok(()) +} diff --git a/zluda_inject/tests/inject.rs b/zluda_inject/tests/inject.rs index 15e5e04..3e6ae97 100644 --- a/zluda_inject/tests/inject.rs +++ b/zluda_inject/tests/inject.rs @@ -6,18 +6,28 @@ fn direct_cuinit() -> io::Result<()> { } #[test] -fn indirect_cuinit() -> io::Result<()> { - run_process_and_check_for_zluda_dump("indirect_cuinit") +fn do_cuinit_early() -> io::Result<()> { + run_process_and_check_for_zluda_dump("do_cuinit_early") +} + +#[test] +fn do_cuinit_late() -> io::Result<()> { + run_process_and_check_for_zluda_dump("do_cuinit_late") } #[test] -fn do_cuinit() -> io::Result<()> { - run_process_and_check_for_zluda_dump("do_cuinit_main") +fn do_cuinit_late_clr() -> io::Result<()> { + run_process_and_check_for_zluda_dump("do_cuinit_late_clr") +} + +#[test] +fn indirect_cuinit() -> io::Result<()> { + run_process_and_check_for_zluda_dump("indirect_cuinit") } #[test] -fn do_cuinit_clr() -> io::Result<()> { - run_process_and_check_for_zluda_dump("do_cuinit_main_clr") +fn subprocess() -> io::Result<()> { + run_process_and_check_for_zluda_dump("subprocess") } fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> { @@ -27,7 +37,11 @@ fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> { let helpers_dir = env!("HELPERS_OUT_DIR"); let exe_under_test = format!("{}{}{}.exe", helpers_dir, std::path::MAIN_SEPARATOR, name); let mut test_cmd = Command::new(&zluda_with_exe); - let test_cmd = test_cmd.arg(&zluda_dump_dll).arg("--").arg(&exe_under_test); + let test_cmd = test_cmd + .arg("--nvcuda") + .arg(&zluda_dump_dll) + .arg("--") + .arg(&exe_under_test); let test_output = test_cmd.output()?; assert!(test_output.status.success()); 
let stderr_text = String::from_utf8(test_output.stderr).unwrap(); diff --git a/zluda_ml/src/impl.rs b/zluda_ml/src/impl.rs index 48141bd..2f82008 100644 --- a/zluda_ml/src/impl.rs +++ b/zluda_ml/src/impl.rs @@ -1,5 +1,3 @@ -use std::io::Write; -use std::slice; use std::{ os::raw::{c_char, c_uint}, ptr, diff --git a/zluda_redirect/Cargo.toml b/zluda_redirect/Cargo.toml index 193732c..2a5c3b1 100644 --- a/zluda_redirect/Cargo.toml +++ b/zluda_redirect/Cargo.toml @@ -11,6 +11,3 @@ crate-type = ["cdylib"] detours-sys = { path = "../detours-sys" } wchar = "0.6" winapi = { version = "0.3", features = [ "sysinfoapi", "memoryapi", "processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] } -tempfile = "3" -goblin = { version = "0.4", default-features = false, features = ["pe64"] } -memoffset = "0.6" \ No newline at end of file diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs index 0f73ced..522d705 100644 --- a/zluda_redirect/src/lib.rs +++ b/zluda_redirect/src/lib.rs @@ -6,32 +6,35 @@ extern crate winapi; use std::{ collections::HashMap, ffi::{c_void, CStr}, - io, mem, - os::raw::c_uint, - ptr, slice, usize, + mem, ptr, slice, usize, }; use detours_sys::{ - DetourAllocateRegionWithinJumpBounds, DetourAttach, DetourEnumerateExports, - DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin, + DetourAttach, DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin, DetourTransactionCommit, DetourUpdateProcessWithDll, DetourUpdateThread, }; -use goblin::pe::{ - self, - header::{CoffHeader, DOS_MAGIC, PE_MAGIC, PE_POINTER_OFFSET}, - optional_header::StandardFields64, -}; -use memoffset::offset_of; -use tempfile::TempDir; use wchar::wch; +use winapi::{ + shared::minwindef::{BOOL, LPVOID}, + um::{ + handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, + minwinbase::LPSECURITY_ATTRIBUTES, + processthreadsapi::{ + CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread, + 
SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW, + }, + tlhelp32::{ + CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32, + }, + winbase::CREATE_SUSPENDED, + winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME}, + }, +}; use winapi::{ shared::minwindef::{DWORD, FALSE, HMODULE, TRUE}, um::{ libloaderapi::{GetModuleHandleA, LoadLibraryExA}, - memoryapi::VirtualProtect, - processthreadsapi::{FlushInstructionCache, GetCurrentProcess}, - sysinfoapi::GetSystemInfo, - winnt::{LPCSTR, PAGE_READWRITE}, + winnt::LPCSTR, }, }; use winapi::{ @@ -47,26 +50,6 @@ use winapi::{ shared::winerror::NO_ERROR, um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW}, }; -use winapi::{ - shared::{ - minwindef::{BOOL, LPVOID}, - winerror::E_UNEXPECTED, - }, - um::{ - handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, - libloaderapi::GetModuleHandleW, - minwinbase::LPSECURITY_ATTRIBUTES, - processthreadsapi::{ - CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread, - SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW, - }, - tlhelp32::{ - CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32, - }, - winbase::{CopyFileW, CreateSymbolicLinkW, CREATE_SUSPENDED}, - winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME}, - }, -}; include!("payload_guid.rs"); @@ -74,13 +57,12 @@ const NVCUDA_UTF8: &'static str = "NVCUDA.DLL"; const NVCUDA_UTF16: &[u16] = wch!("NVCUDA.DLL"); const NVML_UTF8: &'static str = "NVML.DLL"; const NVML_UTF16: &[u16] = wch!("NVML.DLL"); -static mut ZLUDA_PATH_UTF8: Vec = Vec::new(); -static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None; -static mut ZLUDA_ML_PATH_UTF8: Vec = Vec::new(); -static mut ZLUDA_ML_PATH_UTF16: Option<&'static [u16]> = None; +static mut ZLUDA_PATH_UTF8: Option<&'static [u8]> = None; +static mut ZLUDA_PATH_UTF16: Vec = Vec::new(); +static mut ZLUDA_ML_PATH_UTF8: Option<&'static 
[u8]> = None; +static mut ZLUDA_ML_PATH_UTF16: Vec = Vec::new(); static mut CURRENT_MODULE_FILENAME: Vec = Vec::new(); static mut DETOUR_STATE: Option = None; -const CUDA_ERROR_NOT_SUPPORTED: c_uint = 801; #[no_mangle] #[used] @@ -197,9 +179,9 @@ unsafe extern "system" fn ZludaLoadLibraryW_NoRedirect(lpLibFileName: LPCWSTR) - #[allow(non_snake_case)] unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE { let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) { - ZLUDA_PATH_UTF8.as_ptr() as *const _ + ZLUDA_PATH_UTF8.unwrap().as_ptr() as *const _ } else if is_nvml_dll_utf8(lpLibFileName as *const _) { - ZLUDA_ML_PATH_UTF8.as_ptr() as *const _ + ZLUDA_ML_PATH_UTF8.unwrap().as_ptr() as *const _ } else { lpLibFileName }; @@ -209,9 +191,9 @@ unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE { #[allow(non_snake_case)] unsafe extern "system" fn ZludaLoadLibraryW(lpLibFileName: LPCWSTR) -> HMODULE { let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) { - ZLUDA_PATH_UTF16.unwrap().as_ptr() + ZLUDA_PATH_UTF16.as_ptr() } else if is_nvml_dll_utf16(lpLibFileName as *const _) { - ZLUDA_ML_PATH_UTF16.unwrap().as_ptr() + ZLUDA_ML_PATH_UTF16.as_ptr() } else { lpLibFileName }; @@ -225,9 +207,9 @@ unsafe extern "system" fn ZludaLoadLibraryExA( dwFlags: DWORD, ) -> HMODULE { let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) { - ZLUDA_PATH_UTF8.as_ptr() as *const _ + ZLUDA_PATH_UTF8.unwrap().as_ptr() as *const _ } else if is_nvml_dll_utf8(lpLibFileName as *const _) { - ZLUDA_ML_PATH_UTF8.as_ptr() as *const _ + ZLUDA_ML_PATH_UTF8.unwrap().as_ptr() as *const _ } else { lpLibFileName }; @@ -241,9 +223,9 @@ unsafe extern "system" fn ZludaLoadLibraryExW( dwFlags: DWORD, ) -> HMODULE { let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) { - ZLUDA_PATH_UTF16.unwrap().as_ptr() + ZLUDA_PATH_UTF16.as_ptr() } else if is_nvml_dll_utf16(lpLibFileName as *const _) { - 
ZLUDA_ML_PATH_UTF16.unwrap().as_ptr() + ZLUDA_ML_PATH_UTF16.as_ptr() } else { lpLibFileName }; @@ -392,67 +374,6 @@ unsafe extern "system" fn ZludaCreateProcessWithTokenW( continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation) } -static mut MAIN: unsafe extern "system" fn() -> DWORD = zluda_main; -static mut COR_EXE_MAIN: unsafe extern "system" fn() -> DWORD = zluda_main_clr; - -// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#search-order-for-desktop-applications -// "If a DLL with the same module name is already loaded in memory, the system -// uses the loaded DLL, no matter which directory it is in. The system does not -// search for the DLL." -unsafe extern "system" fn zluda_main() -> DWORD { - zluda_main_impl(MAIN) -} - -unsafe extern "system" fn zluda_main_clr() -> DWORD { - zluda_main_impl(COR_EXE_MAIN) -} - -unsafe fn zluda_main_impl(original: unsafe extern "system" fn() -> DWORD) -> DWORD { - let temp_dir = match do_zluda_preload() { - Ok(f) => f, - Err(e) => return e.raw_os_error().unwrap_or(E_UNEXPECTED) as DWORD, - }; - let result = original(); - drop(temp_dir); - result -} - -unsafe fn do_zluda_preload() -> std::io::Result { - let temp_dir = tempfile::tempdir()?; - do_single_zluda_preload(&temp_dir, ZLUDA_PATH_UTF16.unwrap().as_ptr(), NVCUDA_UTF8)?; - do_single_zluda_preload(&temp_dir, ZLUDA_ML_PATH_UTF16.unwrap().as_ptr(), NVML_UTF8)?; - Ok(temp_dir) -} - -unsafe fn do_single_zluda_preload( - temp_dir: &TempDir, - full_path: *const u16, - file_name: &'static str, -) -> io::Result<()> { - let mut temp_file_path = temp_dir.path().to_path_buf(); - temp_file_path.push(file_name); - let mut temp_file_path_utf16 = temp_file_path - .into_os_string() - .to_string_lossy() - .encode_utf16() - .collect::>(); - temp_file_path_utf16.push(0); - // Probably we are not in developer mode, do a copty then - if 0 == CreateSymbolicLinkW( - temp_file_path_utf16.as_ptr(), - full_path, - 0x2, 
//SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE - ) { - if 0 == CopyFileW(full_path, temp_file_path_utf16.as_ptr(), 1) { - return Err(io::Error::last_os_error()); - } - } - if ptr::null_mut() == ZludaLoadLibraryW_NoRedirect(temp_file_path_utf16.as_ptr()) { - return Err(io::Error::last_os_error()); - } - Ok(()) -} - // This type encapsulates typical calling sequence of detours and cleanup. // We have two ways we do detours: // * If we are loaded before nvcuda.dll, we hook LoadLibrary* @@ -633,21 +554,31 @@ unsafe fn continue_create_process_hook( // continues uninjected than to break the parent if DetourUpdateProcessWithDll( (*process_information).hProcess, - &mut CURRENT_MODULE_FILENAME.as_ptr() as *mut _ as *mut _, + &mut ZLUDA_ML_PATH_UTF8.unwrap().as_ptr() as *mut _ as *mut _, 1, ) != FALSE + && DetourUpdateProcessWithDll( + (*process_information).hProcess, + &mut ZLUDA_PATH_UTF8.unwrap().as_ptr() as *mut _ as *mut _, + 1, + ) != FALSE + && DetourUpdateProcessWithDll( + (*process_information).hProcess, + &mut CURRENT_MODULE_FILENAME.as_ptr() as *mut _ as *mut _, + 1, + ) != FALSE { detours_sys::DetourCopyPayloadToProcess( (*process_information).hProcess, &PAYLOAD_NVML_GUID, - ZLUDA_ML_PATH_UTF16.unwrap().as_ptr() as *mut _, - (ZLUDA_ML_PATH_UTF16.unwrap().len() * mem::size_of::()) as u32, + ZLUDA_ML_PATH_UTF16.as_ptr() as *mut _, + (ZLUDA_ML_PATH_UTF16.len() * mem::size_of::()) as u32, ); detours_sys::DetourCopyPayloadToProcess( (*process_information).hProcess, &PAYLOAD_NVCUDA_GUID, - ZLUDA_PATH_UTF16.unwrap().as_ptr() as *mut _, - (ZLUDA_PATH_UTF16.unwrap().len() * mem::size_of::()) as u32, + ZLUDA_PATH_UTF16.as_ptr() as *mut _, + (ZLUDA_PATH_UTF16.len() * mem::size_of::()) as u32, ); } if original_creation_flags & CREATE_SUSPENDED == 0 { @@ -733,23 +664,18 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u } match get_zluda_dlls_paths() { Some((nvcuda_path, nvml_path)) => { - ZLUDA_PATH_UTF16 = Some(nvcuda_path); - 
ZLUDA_ML_PATH_UTF16 = Some(nvml_path); - // from_utf16_lossy(...) handles terminating NULL correctly - ZLUDA_PATH_UTF8 = String::from_utf16_lossy(nvcuda_path).into_bytes(); - ZLUDA_ML_PATH_UTF8 = String::from_utf16_lossy(nvml_path).into_bytes(); + ZLUDA_PATH_UTF8 = Some(nvcuda_path); + ZLUDA_ML_PATH_UTF8 = Some(nvml_path); + ZLUDA_PATH_UTF16 = std::str::from_utf8_unchecked(nvcuda_path) + .encode_utf16() + .collect::>(); + ZLUDA_ML_PATH_UTF16 = std::str::from_utf8_unchecked(nvml_path) + .encode_utf16() + .collect::>(); } None => return FALSE, } - // If the application (directly or not) links to nvcuda.dll, nvcuda.dll - // will get loaded before we can act. In this case, instead of - // redirecting LoadLibrary* to load ZLUDA, we override already loaded - // functions - let detach_guard = match get_cuinit() { - Some((nvcuda_mod, _)) => detour_already_loaded_nvcuda(nvcuda_mod), - None => detour_main(), - }; - match detach_guard { + match detour_already_loaded_nvcuda() { Some(g) => { DETOUR_STATE = Some(g); TRUE @@ -787,42 +713,9 @@ unsafe fn initialize_current_module_name(current_module: HINSTANCE) -> bool { } } -unsafe fn get_cuinit() -> Option<(HMODULE, FARPROC)> { - let nvcuda = GetModuleHandleA(b"nvcuda\0".as_ptr() as _); - if nvcuda == ptr::null_mut() { - return None; - } - let cuinit_addr = GetProcAddress(nvcuda, b"cuInit\0".as_ptr() as _); - if cuinit_addr == ptr::null_mut() { - return None; - } - Some((nvcuda as *mut _, cuinit_addr)) -} - #[must_use] -unsafe fn detour_already_loaded_nvcuda(nvcuda_mod: HMODULE) -> Option { - let zluda_module = LoadLibraryW(ZLUDA_PATH_UTF16.unwrap().as_ptr()); - if zluda_module == ptr::null_mut() { - return None; - } - let original_functions = gather_imports(nvcuda_mod); - let override_functions = gather_imports(zluda_module); - let mut override_fn_pairs = HashMap::with_capacity(original_functions.len()); - // TODO: optimize - for (original_fn_name, original_fn_address) in original_functions { - let override_fn_address = - match 
override_functions.binary_search_by_key(&original_fn_name, |(name, _)| *name) { - Ok(x) => override_functions[x].1, - Err(_) => { - // TODO: print a warning in debug - cuda_unsupported as _ - } - }; - override_fn_pairs.insert( - original_fn_name, - (original_fn_address as _, override_fn_address), - ); - } +unsafe fn detour_already_loaded_nvcuda() -> Option { + let nvcuda_mod = GetModuleHandleA(b"nvcuda\0".as_ptr() as _); let detour_functions = vec![ ( &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void, @@ -838,146 +731,10 @@ unsafe fn detour_already_loaded_nvcuda(nvcuda_mod: HMODULE) -> Option c_uint { - CUDA_ERROR_NOT_SUPPORTED -} - -unsafe fn gather_imports(module: HINSTANCE) -> Vec<(&'static CStr, *mut c_void)> { - let mut result = Vec::new(); - DetourEnumerateExports( - module as _, - &mut result as *mut _ as *mut _, - Some(gather_imports_impl), - ); - result -} - -unsafe extern "stdcall" fn gather_imports_impl( - context: *mut c_void, - _: u32, - name: LPCSTR, - code: *mut c_void, -) -> i32 { - let result: &mut Vec<(&'static CStr, *mut c_void)> = &mut *(context as *mut Vec<_>); - result.push((CStr::from_ptr(name), code)); - TRUE -} - -#[must_use] -unsafe fn detour_main() -> Option { - if !override_entry_point() { - return None; - } - let mut detour_functions = vec![ - ( - &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void, - ZludaLoadLibraryA as *mut c_void, - ), - (&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _), - ( - &mut LOAD_LIBRARY_EX_A as *mut _ as _, - ZludaLoadLibraryExA as _, - ), - ( - &mut LOAD_LIBRARY_EX_W as *mut _ as _, - ZludaLoadLibraryExW as _, - ), - ]; - detour_functions.extend(get_clr_entry_point()); - DetourDetachGuard::detour_functions(ptr::null_mut(), detour_functions, HashMap::new()) -} - -unsafe fn override_entry_point() -> bool { - let exe_handle = GetModuleHandleW(ptr::null()); - let dos_signature = exe_handle as *mut u16; - if *dos_signature != DOS_MAGIC { - return false; - } - let pe_offset = *((exe_handle as *mut 
u8).add(PE_POINTER_OFFSET as usize) as *mut u32); - let pe_sig = (exe_handle as *mut u8).add(pe_offset as usize) as *mut u32; - if (*pe_sig) != PE_MAGIC { - return false; - } - let coff_header = pe_sig.add(1) as *mut CoffHeader; - let standard_coff_fields = coff_header.add(1) as *mut StandardFields64; - if (*standard_coff_fields).magic != pe::optional_header::MAGIC_64 { - return false; - } - let entry_point = mem::transmute::<_, unsafe extern "system" fn() -> DWORD>( - (exe_handle as *mut u8).add((*standard_coff_fields).address_of_entry_point as usize), - ); - let mut allocated_size = 0; - let exe_region = DetourAllocateRegionWithinJumpBounds(exe_handle as _, &mut allocated_size); - if (allocated_size as usize) < mem::size_of::() || exe_region == ptr::null_mut() { - return false; - } - MAIN = entry_point; - *(exe_region as *mut JmpThunk64) = JmpThunk64::new(zluda_main); - FlushInstructionCache( - GetCurrentProcess(), - exe_region, - mem::size_of::(), - ); - let new_address_of_entry_point = (exe_region as *mut u8).offset_from(exe_handle as *mut u8); - let entry_point_offset = offset_of!(StandardFields64, address_of_entry_point); - let mut system_info = mem::zeroed(); - GetSystemInfo(&mut system_info); - let pointer_to_address_of_entry_point = - (standard_coff_fields as *mut u8).add(entry_point_offset) as *mut i32; - let page_size = system_info.dwPageSize as usize; - let page_start = (((pointer_to_address_of_entry_point as usize) / page_size) * page_size) as _; - let mut old_protect = 0; - if VirtualProtect(page_start, page_size, PAGE_READWRITE, &mut old_protect) == 0 { - return false; - } - *pointer_to_address_of_entry_point = new_address_of_entry_point as i32; - if VirtualProtect(page_start, page_size, old_protect, &mut old_protect) == 0 { - return false; - } - true -} - -// mov rax, $address; -// jmp rax; -// int 3; -#[repr(packed)] -#[allow(dead_code)] -#[cfg(target_pointer_width = "64")] -struct JmpThunk64 { - mov_rax: [u8; 2], - address: u64, - jmp_rax: [u8; 
2], - int3: u8, -} - -impl JmpThunk64 { - fn new(target: unsafe extern "system" fn() -> T) -> Self { - JmpThunk64 { - mov_rax: [0x48, 0xB8], - address: target as u64, - jmp_rax: [0xFF, 0xE0], - int3: 0xcc, - } - } -} - -unsafe fn get_clr_entry_point() -> Option<(*mut *mut c_void, *mut c_void)> { - let mscoree = GetModuleHandleA(b"mscoree\0".as_ptr() as _); - if mscoree == ptr::null_mut() { - return None; - } - let proc = GetProcAddress(mscoree, b"_CorExeMain\0".as_ptr() as _); - if proc == ptr::null_mut() { - return None; - } - COR_EXE_MAIN = mem::transmute(proc); - Some((&mut COR_EXE_MAIN as *mut _ as _, zluda_main_clr as _)) -} - -fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> { +fn get_zluda_dlls_paths() -> Option<(&'static [u8], &'static [u8])> { match get_payload(&PAYLOAD_NVCUDA_GUID) { None => None, Some(nvcuda_payload) => match get_payload(&PAYLOAD_NVML_GUID) { @@ -987,22 +744,17 @@ fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> { } } -fn get_payload(guid: &detours_sys::GUID) -> Option<&'static [u16]> { - let mut module = ptr::null_mut(); - loop { - module = unsafe { detours_sys::DetourEnumerateModules(module) }; - if module == ptr::null_mut() { - return None; - } - let mut size = 0; - let payload_ptr = unsafe { detours_sys::DetourFindPayload(module, guid, &mut size) }; - if payload_ptr != ptr::null_mut() { - return Some(unsafe { - slice::from_raw_parts( - payload_ptr as *const _, - (size as usize) / mem::size_of::(), - ) - }); - } +fn get_payload(guid: &detours_sys::GUID) -> Option<&'static [u8]> { + let mut size = 0; + let payload_ptr = unsafe { detours_sys::DetourFindPayloadEx(guid, &mut size) }; + if payload_ptr != ptr::null_mut() { + Some(unsafe { + slice::from_raw_parts( + payload_ptr as *const _, + (size as usize) / mem::size_of::(), + ) + }) + } else { + None } } -- cgit v1.2.3