aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_inject
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-02-27 20:55:19 +0100
committerAndrzej Janik <[email protected]>2024-02-11 20:45:51 +0100
commit1b9ba2b2333746c5e2b05a2bf24fa6ec3828dcdf (patch)
tree0b77ca4a41d4f232bd181e2bddc886475c608784 /zluda_inject
parent60d2124a16a7a2a1a6be3707247afe82892a4163 (diff)
downloadZLUDA-1b9ba2b2333746c5e2b05a2bf24fa6ec3828dcdf.tar.gz
ZLUDA-1b9ba2b2333746c5e2b05a2bf24fa6ec3828dcdf.zip
Nobody expects the Red Teamv3
Too many changes to list, but broadly:
* Remove Intel GPU support from the compiler
* Add AMD GPU support to the compiler
* Remove Intel GPU host code
* Add AMD GPU host code
* More device instructions: from 40 to 68
* More host functions: from 48 to 184
* Add a proof-of-concept implementation of the OptiX framework
* Add minimal support for cuDNN, cuBLAS, cuSPARSE, cuFFT, NCCL, NVML
* Improve the ZLUDA launcher for Windows
Diffstat (limited to 'zluda_inject')
-rw-r--r--zluda_inject/Cargo.toml18
-rw-r--r--zluda_inject/build.rs96
-rw-r--r--zluda_inject/src/bin.rs326
-rw-r--r--zluda_inject/tests/helpers/direct_cuinit.rs9
-rw-r--r--zluda_inject/tests/helpers/do_cuinit.rs10
-rw-r--r--zluda_inject/tests/helpers/do_cuinit_early.rs10
-rw-r--r--zluda_inject/tests/helpers/do_cuinit_late.rs23
-rw-r--r--zluda_inject/tests/helpers/do_cuinit_late_clr.cs34
-rw-r--r--zluda_inject/tests/helpers/do_cuinit_late_clr.exebin0 -> 4608 bytes
-rw-r--r--zluda_inject/tests/helpers/indirect_cuinit.rs16
-rw-r--r--zluda_inject/tests/helpers/nvcuda.libbin0 -> 71050 bytes
-rw-r--r--zluda_inject/tests/helpers/query_exe.rs11
-rw-r--r--zluda_inject/tests/helpers/subprocess.rs10
-rw-r--r--zluda_inject/tests/inject.rs67
14 files changed, 559 insertions, 71 deletions
diff --git a/zluda_inject/Cargo.toml b/zluda_inject/Cargo.toml
index 1181a21..c1a7066 100644
--- a/zluda_inject/Cargo.toml
+++ b/zluda_inject/Cargo.toml
@@ -5,9 +5,23 @@ authors = ["Andrzej Janik <[email protected]>"]
edition = "2018"
[[bin]]
-name = "zluda_with"
+name = "zluda"
path = "src/main.rs"
[target.'cfg(windows)'.dependencies]
-winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchapi", "winbase", "std"] }
+winapi = { version = "0.3.9", features = ["jobapi", "jobapi2", "processenv", "processthreadsapi", "synchapi", "winbase", "std"] }
+tempfile = "3"
+argh = "0.1"
detours-sys = { path = "../detours-sys" }
+
+[dev-dependencies]
+# all of those are used in integration tests
+zluda_redirect = { path = "../zluda_redirect" }
+zluda_dump = { path = "../zluda_dump" }
+zluda_ml = { path = "../zluda_ml" }
+
+[build-dependencies]
+embed-manifest = "1.3.1"
+
+[package.metadata.zluda]
+windows_only = true
diff --git a/zluda_inject/build.rs b/zluda_inject/build.rs
new file mode 100644
index 0000000..591815a
--- /dev/null
+++ b/zluda_inject/build.rs
@@ -0,0 +1,96 @@
+use embed_manifest::{embed_manifest, new_manifest};
+use std::{
+ env::{self, VarError},
+ fs::{self, DirEntry},
+ io,
+ path::{self, PathBuf},
+ process::Command,
+};
+
+/// Build script entry point for `zluda_inject`.
+///
+/// When targeting Windows an application manifest is embedded into the
+/// binary. For debug builds targeting Windows, every `.rs` file found in
+/// `tests/helpers` is additionally compiled with `rustc` into `OUT_DIR`, the
+/// prebuilt `do_cuinit_late_clr.exe` is copied next to the results, and that
+/// directory is exported to the crate as the `HELPERS_OUT_DIR` compile-time
+/// environment variable.
+///
+/// # Errors
+/// Returns `VarError` when a required environment variable is not set or is
+/// not valid Unicode.
+///
+/// # Panics
+/// Panics when the manifest cannot be embedded, `DEBUG` does not parse as a
+/// `bool`, the helpers directory cannot be read, a helper fails to compile, or
+/// the prebuilt executable cannot be copied.
+fn main() -> Result<(), VarError> {
+    if std::env::var_os("CARGO_CFG_WINDOWS").is_some() {
+        embed_manifest(new_manifest("zluda_with")).expect("unable to embed manifest file");
+    }
+    println!("cargo:rerun-if-changed=build.rs");
+    // The test helpers are only built for debug builds targeting Windows.
+    if env::var("PROFILE")? != "debug" {
+        return Ok(());
+    }
+    if env::var("CARGO_CFG_TARGET_OS")? != "windows" {
+        return Ok(());
+    }
+    let rustc_exe = env::var("RUSTC")?;
+    let out_dir = env::var("OUT_DIR")?;
+    let target = env::var("TARGET")?;
+    let is_msvc = env::var("CARGO_CFG_TARGET_ENV")? == "msvc";
+    let opt_level = env::var("OPT_LEVEL")?;
+    let debug = env::var("DEBUG")?.parse::<bool>().unwrap();
+    let helpers_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?)
+        .join("tests")
+        .join("helpers");
+    let helpers_dir_as_string = helpers_dir.to_string_lossy();
+    println!("cargo:rerun-if-changed={}", helpers_dir_as_string);
+    // `fs::read_dir` yields entries in a platform- and filesystem-dependent
+    // order, but `do_cuinit_early.rs` links against the import library that is
+    // produced by compiling `do_cuinit.rs`. Sorting the names makes the build
+    // order deterministic and puts `do_cuinit.rs` before `do_cuinit_early.rs`.
+    let mut helper_sources: Vec<String> = fs::read_dir(&helpers_dir)
+        .unwrap()
+        .filter_map(rust_file)
+        .collect();
+    helper_sources.sort_unstable();
+    for helper_source in &helper_sources {
+        let mut rustc_cmd = Command::new(&rustc_exe);
+        if debug {
+            rustc_cmd.arg("-g");
+        }
+        rustc_cmd.arg(format!("-Lnative={}", helpers_dir_as_string));
+        if is_msvc {
+            // For some reason rustc emits foobar.dll.lib and then expects foobar.lib
+            let stem = path::Path::new(helper_source)
+                .file_stem()
+                .unwrap()
+                .to_string_lossy();
+            let implib_path = path::Path::new(&out_dir).join(format!("{}.lib", stem));
+            let link_args = format!("link-args=/IMPLIB:{}", implib_path.display());
+            rustc_cmd.args(["-C", link_args.as_str()]);
+        } else {
+            // HACK ALERT
+            // I have no idea why the extra library below has to be linked
+            rustc_cmd.arg("-lucrt");
+        }
+        rustc_cmd
+            .arg("-ldylib=nvcuda")
+            .arg("-C")
+            .arg(format!("opt-level={}", opt_level))
+            .arg("-L")
+            .arg(&out_dir)
+            .arg("--out-dir")
+            .arg(&out_dir)
+            .arg("--target")
+            .arg(&target)
+            .arg(helpers_dir.join(helper_source));
+        let status = rustc_cmd.status().unwrap();
+        assert!(
+            status.success(),
+            "rustc failed to compile test helper {}",
+            helper_source
+        );
+    }
+    // The CLR helper is shipped prebuilt, so it is copied rather than compiled.
+    const CLR_HELPER: &str = "do_cuinit_late_clr.exe";
+    fs::copy(
+        helpers_dir.join(CLR_HELPER),
+        path::Path::new(&out_dir).join(CLR_HELPER),
+    )
+    .unwrap();
+    println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir);
+    Ok(())
+}
+
+/// Maps a directory entry to its file name when the entry is a regular file
+/// whose name ends with `.rs`.
+///
+/// Returns `None` for unreadable entries, for entries whose file type cannot
+/// be determined, for non-files and for files with any other suffix. The name
+/// is converted lossily, so invalid Unicode is replaced rather than rejected.
+fn rust_file(entry: io::Result<DirEntry>) -> Option<String> {
+    let entry = entry.ok()?;
+    let name = entry.file_name().to_string_lossy().into_owned();
+    let is_regular_file = match entry.file_type() {
+        Ok(file_type) => file_type.is_file(),
+        Err(_) => false,
+    };
+    if !is_regular_file || !name.ends_with(".rs") {
+        return None;
+    }
+    Some(name)
+}
diff --git a/zluda_inject/src/bin.rs b/zluda_inject/src/bin.rs
index ce83fe9..df664cf 100644
--- a/zluda_inject/src/bin.rs
+++ b/zluda_inject/src/bin.rs
@@ -1,10 +1,14 @@
-use std::mem;
-use std::path::Path;
-use std::ptr;
-use std::{env, ops::Deref};
+use std::env;
+use std::os::windows;
+use std::os::windows::ffi::OsStrExt;
use std::{error::Error, process};
+use std::{fs, io, ptr};
+use std::{mem, path::PathBuf};
+use argh::FromArgs;
use mem::size_of_val;
+use tempfile::TempDir;
+use winapi::um::processenv::SearchPathW;
use winapi::um::{
jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
processthreadsapi::{GetExitCodeProcess, ResumeThread},
@@ -19,26 +23,62 @@ use winapi::um::{
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
-static ZLUDA_DLL: &'static str = "nvcuda.dll";
+static NVCUDA_DLL: &'static str = "nvcuda.dll";
+static NVML_DLL: &'static str = "nvml.dll";
+static NVAPI_DLL: &'static str = "nvapi64.dll";
+static NVOPTIX_DLL: &'static str = "optix.6.6.0.dll";
include!("../../zluda_redirect/src/payload_guid.rs");
+// Command-line interface of the launcher. It is parsed with
+// `argh::from_env::<ProgramArguments>()` in `main_impl`.
+// NOTE(review): the `argh` derive presumably turns the `///` comments below
+// into the generated `--help` text, which would make them user-facing strings;
+// confirm before rewording any of them.
+// NOTE(review): the `{0}` in the `nvcuda`/`nvml` descriptions looks like a
+// leftover from the removed `print_help_and_exit` format string, where `{0}`
+// was the executable name - TODO confirm how it renders in the help output.
+#[derive(FromArgs)]
+/// Launch application with custom CUDA libraries
+struct ProgramArguments {
+ /// DLL to be injected instead of system nvcuda.dll. If not provided {0}, will use nvcuda.dll from its own directory
+ #[argh(option)]
+ nvcuda: Option<PathBuf>,
+
+ /// DLL to be injected instead of system nvml.dll. If not provided {0}, will use nvml.dll from its own directory
+ #[argh(option)]
+ nvml: Option<PathBuf>,
+
+ /// DLL to be injected instead of system nvapi64.dll. If not provided, no injection will take place
+ #[argh(option)]
+ nvapi: Option<PathBuf>,
+
+ /// DLL to be injected instead of system nvoptix.dll. If not provided, no injection will take place
+ #[argh(option)]
+ nvoptix: Option<PathBuf>,
+
+ /// executable to be injected with custom CUDA libraries
+ #[argh(positional)]
+ exe: String,
+
+ /// arguments to the executable
+ #[argh(positional)]
+ args: Vec<String>,
+}
+
pub fn main_impl() -> Result<(), Box<dyn Error>> {
- let args = env::args().collect::<Vec<_>>();
- if args.len() <= 1 {
- print_help_and_exit();
- }
- let injector_path = env::current_exe()?;
- let injector_dir = injector_path.parent().unwrap();
- let redirect_path = create_redirect_path(injector_dir);
- let (mut inject_path, cmd) = create_inject_path(&args[1..], injector_dir);
- let mut cmd_line = construct_command_line(cmd);
+ let raw_args = argh::from_env::<ProgramArguments>();
+ let normalized_args = NormalizedArguments::new(raw_args)?;
+ let mut environment = Environment::setup(normalized_args)?;
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
+ let mut dlls_to_inject = vec![
+ environment.nvcuda_path_zero_terminated.as_ptr() as _,
+ environment.nvml_path_zero_terminated.as_ptr() as *const i8,
+ environment.redirect_path_zero_terminated.as_ptr() as _,
+ ];
+ if let Some(ref nvapi) = environment.nvapi_path_zero_terminated {
+ dlls_to_inject.push(nvapi.as_ptr() as _);
+ }
+ if let Some(ref nvoptix) = environment.nvoptix_path_zero_terminated {
+ dlls_to_inject.push(nvoptix.as_ptr() as _);
+ }
os_call!(
- detours_sys::DetourCreateProcessWithDllExW(
+ detours_sys::DetourCreateProcessWithDllsW(
ptr::null(),
- cmd_line.as_mut_ptr(),
+ environment.winapi_command_line_zero_terminated.as_mut_ptr(),
ptr::null_mut(),
ptr::null_mut(),
0,
@@ -47,7 +87,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
ptr::null(),
&mut startup_info as *mut _,
&mut proc_info as *mut _,
- redirect_path.as_ptr() as *const i8,
+ dlls_to_inject.len() as u32,
+ dlls_to_inject.as_mut_ptr(),
Option::None
),
|x| x != 0
@@ -56,12 +97,43 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
os_call!(
detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess,
- &PAYLOAD_GUID,
- inject_path.as_mut_ptr() as *mut _,
- (inject_path.len() * mem::size_of::<u16>()) as u32
+ &PAYLOAD_NVCUDA_GUID,
+ environment.nvcuda_path_zero_terminated.as_ptr() as *mut _,
+ environment.nvcuda_path_zero_terminated.len() as u32
+ ),
+ |x| x != 0
+ );
+ os_call!(
+ detours_sys::DetourCopyPayloadToProcess(
+ proc_info.hProcess,
+ &PAYLOAD_NVML_GUID,
+ environment.nvml_path_zero_terminated.as_ptr() as *mut _,
+ environment.nvml_path_zero_terminated.len() as u32
),
|x| x != 0
);
+ if let Some(nvapi) = environment.nvapi_path_zero_terminated {
+ os_call!(
+ detours_sys::DetourCopyPayloadToProcess(
+ proc_info.hProcess,
+ &PAYLOAD_NVAPI_GUID,
+ nvapi.as_ptr() as *mut _,
+ nvapi.len() as u32
+ ),
+ |x| x != 0
+ );
+ }
+ if let Some(nvoptix) = environment.nvoptix_path_zero_terminated {
+ os_call!(
+ detours_sys::DetourCopyPayloadToProcess(
+ proc_info.hProcess,
+ &PAYLOAD_NVOPTIX_GUID,
+ nvoptix.as_ptr() as *mut _,
+ nvoptix.len() as u32
+ ),
+ |x| x != 0
+ );
+ }
os_call!(ResumeThread(proc_info.hThread), |x| x as i32 != -1);
os_call!(WaitForSingleObject(proc_info.hProcess, INFINITE), |x| x
!= WAIT_FAILED);
@@ -73,6 +145,168 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
process::exit(child_exit_code as i32)
}
+/// Launcher arguments after normalization.
+///
+/// The nvcuda/nvml paths fall back to DLLs located next to the launcher when
+/// they were not given, relative DLL paths have been resolved through
+/// `SearchPathW`, and the target command line is already encoded as a
+/// zero-terminated UTF-16 string.
+struct NormalizedArguments {
+    nvcuda_path: PathBuf,
+    nvml_path: PathBuf,
+    nvapi_path: Option<PathBuf>,
+    nvoptix_path: Option<PathBuf>,
+    redirect_path: PathBuf,
+    winapi_command_line_zero_terminated: Vec<u16>,
+}
+
+impl NormalizedArguments {
+    /// Initial size, in UTF-16 code units, of the buffer handed to `SearchPathW`.
+    const WIN_MAX_PATH: usize = 260;
+
+    /// Normalizes the parsed command-line arguments.
+    ///
+    /// Fails when the path of the current executable cannot be determined or
+    /// when resolving one of the DLL paths fails.
+    fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
+        let ProgramArguments {
+            nvcuda,
+            nvml,
+            nvapi,
+            nvoptix,
+            exe,
+            args,
+        } = prog_args;
+        let current_exe = env::current_exe()?;
+        let nvcuda_path = Self::get_absolute_path_or_default(&current_exe, nvcuda, NVCUDA_DLL)?;
+        let nvml_path = Self::get_absolute_path_or_default(&current_exe, nvml, NVML_DLL)?;
+        let nvapi_path = match nvapi {
+            Some(path) => Some(Self::get_absolute_path(path)?),
+            None => None,
+        };
+        let nvoptix_path = match nvoptix {
+            Some(path) => Some(Self::get_absolute_path(path)?),
+            None => None,
+        };
+        let winapi_command_line_zero_terminated =
+            construct_command_line(std::iter::once(exe).chain(args));
+        // The redirect DLL is always taken from the launcher's own directory.
+        let redirect_path = current_exe.parent().unwrap().join(REDIRECT_DLL);
+        Ok(Self {
+            nvcuda_path,
+            nvml_path,
+            nvapi_path,
+            nvoptix_path,
+            redirect_path,
+            winapi_command_line_zero_terminated,
+        })
+    }
+
+    /// Resolves `dll` if it was provided; otherwise returns the file named
+    /// `default` inside the directory that contains `current_exe`.
+    fn get_absolute_path_or_default(
+        current_exe: &PathBuf,
+        dll: Option<PathBuf>,
+        default: &str,
+    ) -> Result<PathBuf, Box<dyn Error>> {
+        match dll {
+            Some(dll) => Self::get_absolute_path(dll),
+            None => Ok(current_exe.parent().unwrap().join(default)),
+        }
+    }
+
+    /// Returns `dll` unchanged when it is already absolute; otherwise asks
+    /// `SearchPathW` to locate it and returns the path it reports.
+    fn get_absolute_path(dll: PathBuf) -> Result<PathBuf, Box<dyn Error>> {
+        if dll.is_absolute() {
+            return Ok(dll);
+        }
+        let file_name_wide = dll
+            .as_os_str()
+            .encode_wide()
+            .chain(std::iter::once(0))
+            .collect::<Vec<u16>>();
+        let mut resolved = vec![0u16; Self::WIN_MAX_PATH];
+        let resolved_len = loop {
+            let reported_len = os_call!(
+                SearchPathW(
+                    ptr::null_mut(),
+                    file_name_wide.as_ptr(),
+                    ptr::null(),
+                    resolved.len() as u32,
+                    resolved.as_mut_ptr(),
+                    ptr::null_mut()
+                ),
+                |x| x != 0
+            ) as usize;
+            if reported_len <= resolved.len() {
+                break reported_len;
+            }
+            // The reported length does not fit: grow the buffer and retry.
+            resolved.resize(reported_len + 1, 0);
+        };
+        resolved.truncate(resolved_len);
+        Ok(PathBuf::from(String::from_utf16_lossy(&resolved)))
+    }
+}
+
+/// This struct represents the "environment". By environment we mean all paths
+/// (nvcuda.dll, nvml.dll, etc.) and all related resources, like the temporary
+/// directory which holds DLLs that had to be exposed under a different name.
+///
+/// Every `*_zero_terminated` string ends with a `'\0'` character.
+struct Environment {
+    nvcuda_path_zero_terminated: String,
+    nvml_path_zero_terminated: String,
+    nvapi_path_zero_terminated: Option<String>,
+    nvoptix_path_zero_terminated: Option<String>,
+    redirect_path_zero_terminated: String,
+    winapi_command_line_zero_terminated: Vec<u16>,
+    // Kept only for ownership: the paths above may point into this directory.
+    _temp_dir: TempDir,
+}
+
+impl Environment {
+    /// Builds the environment from normalized arguments.
+    ///
+    /// Each injected DLL is made available under the file name the target
+    /// process expects (see `copy_to_correct_name`).
+    ///
+    /// # Errors
+    /// Fails when the temporary directory cannot be created, when a DLL path
+    /// has no file-name component, or when copying a DLL fails.
+    fn setup(args: NormalizedArguments) -> io::Result<Self> {
+        let _temp_dir = TempDir::new()?;
+        let nvcuda_path_zero_terminated =
+            Self::prepare_dll(args.nvcuda_path, &_temp_dir, NVCUDA_DLL)?;
+        let nvml_path_zero_terminated = Self::prepare_dll(args.nvml_path, &_temp_dir, NVML_DLL)?;
+        let nvapi_path_zero_terminated = args
+            .nvapi_path
+            .map(|nvapi| Self::prepare_dll(nvapi, &_temp_dir, NVAPI_DLL))
+            .transpose()?;
+        let nvoptix_path_zero_terminated = args
+            .nvoptix_path
+            .map(|nvoptix| Self::prepare_dll(nvoptix, &_temp_dir, NVOPTIX_DLL))
+            .transpose()?;
+        let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
+        Ok(Self {
+            nvcuda_path_zero_terminated,
+            nvml_path_zero_terminated,
+            nvapi_path_zero_terminated,
+            nvoptix_path_zero_terminated,
+            redirect_path_zero_terminated,
+            winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated,
+            _temp_dir,
+        })
+    }
+
+    /// Exposes `dll` under `correct_name` and returns the resulting path as a
+    /// zero-terminated string.
+    fn prepare_dll(dll: PathBuf, temp_dir: &TempDir, correct_name: &str) -> io::Result<String> {
+        Self::copy_to_correct_name(dll, temp_dir, correct_name).map(Self::zero_terminate)
+    }
+
+    /// Returns a path to the file `path_buf` whose file name is `correct_name`.
+    ///
+    /// When the file already has that name the path is returned unchanged.
+    /// Otherwise the file is symlinked into `temp_dir` under `correct_name`,
+    /// falling back to a plain copy when the symlink cannot be created.
+    ///
+    /// # Errors
+    /// Returns `InvalidInput` when `path_buf` has no final file-name component
+    /// (for example when it ends in `..`), and propagates any copy error.
+    fn copy_to_correct_name(
+        path_buf: PathBuf,
+        temp_dir: &TempDir,
+        correct_name: &str,
+    ) -> io::Result<PathBuf> {
+        // The path comes from the command line, so a missing file name is a
+        // user error to report, not an invariant violation to panic on.
+        let file_name = path_buf.file_name().ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!("path `{}` does not name a file", path_buf.display()),
+            )
+        })?;
+        if file_name == correct_name {
+            return Ok(path_buf);
+        }
+        let temp_file_path = temp_dir.path().join(correct_name);
+        // Best effort: prefer a symlink, copy the file if that fails for any reason.
+        if windows::fs::symlink_file(&path_buf, &temp_file_path).is_err() {
+            fs::copy(&path_buf, &temp_file_path)?;
+        }
+        Ok(temp_file_path)
+    }
+
+    /// Converts `p` to a string (lossily) and appends a `'\0'` terminator.
+    fn zero_terminate(p: PathBuf) -> String {
+        let mut s = p.to_string_lossy().to_string();
+        s.push('\0');
+        s
+    }
+}
+
fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
!= ptr::null_mut());
@@ -91,29 +325,11 @@ fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
Ok(())
}
-fn print_help_and_exit() -> ! {
- let current_exe = env::current_exe().unwrap();
- let exe_name = current_exe.file_name().unwrap().to_string_lossy();
- println!(
- "USAGE:
- {0} -- <EXE> [ARGS]...
- {0} <DLL> -- <EXE> [ARGS]...
-ARGS:
- <DLL> DLL to ne injected instead of system nvcuda.dll, if not provided
- will use nvcuda.dll from the directory where {0} is located
- <EXE> Path to the executable to be injected with <DLL>
- <ARGS>... Arguments that will be passed to <EXE>
-",
- exe_name
- );
- process::exit(1)
-}
-
// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
-fn construct_command_line(args: &[String]) -> Vec<u16> {
+fn construct_command_line(args: impl Iterator<Item = String>) -> Vec<u16> {
let mut cmd_line = Vec::new();
- let args_len = args.len();
- for (idx, arg) in args.iter().enumerate() {
+ let args_len = args.size_hint().0;
+ for (idx, arg) in args.enumerate() {
if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) {
cmd_line.extend(arg.encode_utf16());
} else {
@@ -164,31 +380,3 @@ fn construct_command_line(args: &[String]) -> Vec<u16> {
cmd_line.push(0);
cmd_line
}
-
-fn create_redirect_path(injector_dir: &Path) -> Vec<u8> {
- let mut injector_dir = injector_dir.to_path_buf();
- injector_dir.push(REDIRECT_DLL);
- let mut result = injector_dir.to_string_lossy().into_owned().into_bytes();
- result.push(0);
- result
-}
-
-fn create_inject_path<'a>(args: &'a [String], injector_dir: &Path) -> (Vec<u16>, &'a [String]) {
- if args.get(0).map(Deref::deref) == Some("--") {
- let mut injector_dir = injector_dir.to_path_buf();
- injector_dir.push(ZLUDA_DLL);
- let mut result = injector_dir
- .to_string_lossy()
- .as_ref()
- .encode_utf16()
- .collect::<Vec<_>>();
- result.push(0);
- (result, &args[1..])
- } else if args.get(1).map(Deref::deref) == Some("--") {
- let mut dll_path = args[0].encode_utf16().collect::<Vec<_>>();
- dll_path.push(0);
- (dll_path, &args[2..])
- } else {
- print_help_and_exit()
- }
-}
diff --git a/zluda_inject/tests/helpers/direct_cuinit.rs b/zluda_inject/tests/helpers/direct_cuinit.rs
new file mode 100644
index 0000000..8341a60
--- /dev/null
+++ b/zluda_inject/tests/helpers/direct_cuinit.rs
@@ -0,0 +1,9 @@
+#![crate_type = "bin"]
+
+// `cuInit` is resolved at link time. There is no `#[link]` attribute here;
+// build.rs compiles every helper with `-ldylib=nvcuda`.
+extern "system" {
+ fn cuInit(flags: u32) -> u32;
+}
+
+// Calls `cuInit(0)` directly from the executable and ignores the returned value.
+fn main() {
+ unsafe { cuInit(0) };
+}
diff --git a/zluda_inject/tests/helpers/do_cuinit.rs b/zluda_inject/tests/helpers/do_cuinit.rs
new file mode 100644
index 0000000..468d56c
--- /dev/null
+++ b/zluda_inject/tests/helpers/do_cuinit.rs
@@ -0,0 +1,10 @@
+#![crate_type = "cdylib"]
+
+// `cuInit` is resolved at link time. There is no `#[link]` attribute here;
+// build.rs compiles every helper with `-ldylib=nvcuda`.
+extern "system" {
+ fn cuInit(flags: u32) -> u32;
+}
+
+// Exported (unmangled) wrapper that forwards `flags` to `cuInit` and returns
+// its result. The other `do_cuinit_*` helpers call `cuInit` through this DLL.
+#[no_mangle]
+unsafe extern "system" fn do_cuinit(flags: u32) -> u32 {
+ cuInit(flags)
+}
diff --git a/zluda_inject/tests/helpers/do_cuinit_early.rs b/zluda_inject/tests/helpers/do_cuinit_early.rs
new file mode 100644
index 0000000..9743f4a
--- /dev/null
+++ b/zluda_inject/tests/helpers/do_cuinit_early.rs
@@ -0,0 +1,10 @@
+#![crate_type = "bin"]
+
+// `do_cuinit` is imported at link time from the `do_cuinit` library, i.e. the
+// cdylib built from do_cuinit.rs.
+#[link(name = "do_cuinit")]
+extern "system" {
+ fn do_cuinit(flags: u32) -> u32;
+}
+
+// Calls `do_cuinit(0)` through the link-time import and ignores the returned value.
+fn main() {
+ unsafe { do_cuinit(0) };
+}
diff --git a/zluda_inject/tests/helpers/do_cuinit_late.rs b/zluda_inject/tests/helpers/do_cuinit_late.rs
new file mode 100644
index 0000000..ab3516d
--- /dev/null
+++ b/zluda_inject/tests/helpers/do_cuinit_late.rs
@@ -0,0 +1,23 @@
+#![crate_type = "bin"]
+
+use std::ffi::c_void;
+use std::mem;
+use std::env;
+use std::path::PathBuf;
+use std::ffi::CString;
+
+extern "system" {
+ fn LoadLibraryA(lpFileName: *const i8) -> *mut c_void;
+ fn GetProcAddress(hModule: *mut c_void, lpProcName: *const u8) -> *mut c_void;
+}
+
+/// Loads `do_cuinit.dll` from the directory of the current executable at run
+/// time, looks up its exported `do_cuinit` function and calls it with flags `0`.
+///
+/// # Panics
+/// Panics when the executable path cannot be determined or is not valid
+/// Unicode, when the DLL cannot be loaded, or when it does not export
+/// `do_cuinit`.
+fn main() {
+    let current_exe = env::current_exe().unwrap();
+    let mut dll = PathBuf::from(current_exe.parent().unwrap());
+    dll.push("do_cuinit.dll");
+    let dll_cstring = CString::new(dll.to_str().unwrap()).unwrap();
+    let do_cuinit_dll = unsafe { LoadLibraryA(dll_cstring.as_ptr()) };
+    assert!(
+        !do_cuinit_dll.is_null(),
+        "failed to load {}",
+        dll.display()
+    );
+    let do_cuinit = unsafe { GetProcAddress(do_cuinit_dll, b"do_cuinit\0".as_ptr()) };
+    // A null pointer must never be turned into a function pointer, so fail
+    // loudly before the transmute.
+    assert!(
+        !do_cuinit.is_null(),
+        "{} does not export do_cuinit",
+        dll.display()
+    );
+    // SAFETY: the pointer is non-null (checked above); the signature matches
+    // the `do_cuinit` export declared in do_cuinit.rs.
+    let do_cuinit =
+        unsafe { mem::transmute::<_, unsafe extern "system" fn(u32) -> u32>(do_cuinit) };
+    unsafe { do_cuinit(0) };
+}
diff --git a/zluda_inject/tests/helpers/do_cuinit_late_clr.cs b/zluda_inject/tests/helpers/do_cuinit_late_clr.cs
new file mode 100644
index 0000000..666c237
--- /dev/null
+++ b/zluda_inject/tests/helpers/do_cuinit_late_clr.cs
@@ -0,0 +1,34 @@
+using System;
+using System.IO;
+using System.Reflection;
+using System.Runtime.InteropServices;
+
+// CLR test helper: loads do_cuinit.dll at run time and calls its `do_cuinit` export.
+namespace Zluda
+{
+ class Program
+ {
+ // Managed signature used to invoke the native `do_cuinit` export.
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ private delegate int CuInit(int flags);
+
+ // Returns 1 when do_cuinit.dll cannot be loaded from the directory of the
+ // entry assembly; otherwise returns the value produced by `do_cuinit(0)`.
+ static int Main(string[] args)
+ {
+ DirectoryInfo exeDirectory = Directory.GetParent(Assembly.GetEntryAssembly().Location);
+ string dllPath = Path.Combine(exeDirectory.ToString(), "do_cuinit.dll");
+ IntPtr nvcuda = NativeMethods.LoadLibrary(dllPath);
+ if (nvcuda == IntPtr.Zero)
+ return 1;
+ // NOTE(review): the result of GetProcAddress is not checked for IntPtr.Zero.
+ IntPtr doCuinitPtr = NativeMethods.GetProcAddress(nvcuda, "do_cuinit");
+ CuInit cuinit = (CuInit)Marshal.GetDelegateForFunctionPointer(doCuinitPtr, typeof(CuInit));
+ return cuinit(0);
+ }
+ }
+
+ // P/Invoke declarations for the kernel32.dll functions used above.
+ static class NativeMethods
+ {
+ [DllImport("kernel32.dll")]
+ public static extern IntPtr LoadLibrary(string dllToLoad);
+
+ [DllImport("kernel32.dll")]
+ public static extern IntPtr GetProcAddress(IntPtr hModule, string procedureName);
+ }
+} \ No newline at end of file
diff --git a/zluda_inject/tests/helpers/do_cuinit_late_clr.exe b/zluda_inject/tests/helpers/do_cuinit_late_clr.exe
new file mode 100644
index 0000000..b8e4975
--- /dev/null
+++ b/zluda_inject/tests/helpers/do_cuinit_late_clr.exe
Binary files differ
diff --git a/zluda_inject/tests/helpers/indirect_cuinit.rs b/zluda_inject/tests/helpers/indirect_cuinit.rs
new file mode 100644
index 0000000..f254dc1
--- /dev/null
+++ b/zluda_inject/tests/helpers/indirect_cuinit.rs
@@ -0,0 +1,16 @@
+#![crate_type = "bin"]
+
+use std::ffi::c_void;
+use std::mem;
+
+extern "system" {
+ fn LoadLibraryA(lpFileName: *const u8) -> *mut c_void;
+ fn GetProcAddress(hModule: *mut c_void, lpProcName: *const u8) -> *mut c_void;
+}
+
+/// Loads `C:\Windows\System32\nvcuda.dll` by its absolute path at run time,
+/// looks up `cuInit` in it and calls it with flags `0`.
+///
+/// # Panics
+/// Panics when the library cannot be loaded or does not export `cuInit`.
+fn main() {
+    let nvcuda = unsafe { LoadLibraryA(b"C:\\Windows\\System32\\nvcuda.dll\0".as_ptr()) };
+    assert!(!nvcuda.is_null(), "failed to load nvcuda.dll");
+    let cu_init = unsafe { GetProcAddress(nvcuda, b"cuInit\0".as_ptr()) };
+    // A null pointer must never be turned into a function pointer, so fail
+    // loudly before the transmute.
+    assert!(!cu_init.is_null(), "nvcuda.dll does not export cuInit");
+    // SAFETY: the pointer is non-null (checked above).
+    // NOTE(review): the signature is assumed to match the `cuInit`
+    // declaration used by the other helpers - confirm.
+    let cu_init = unsafe { mem::transmute::<_, unsafe extern "system" fn(u32) -> u32>(cu_init) };
+    unsafe { cu_init(0) };
+}
diff --git a/zluda_inject/tests/helpers/nvcuda.lib b/zluda_inject/tests/helpers/nvcuda.lib
new file mode 100644
index 0000000..b793c56
--- /dev/null
+++ b/zluda_inject/tests/helpers/nvcuda.lib
Binary files differ
diff --git a/zluda_inject/tests/helpers/query_exe.rs b/zluda_inject/tests/helpers/query_exe.rs
new file mode 100644
index 0000000..057de39
--- /dev/null
+++ b/zluda_inject/tests/helpers/query_exe.rs
@@ -0,0 +1,11 @@
+#![crate_type = "bin"]
+
+use std::io;
+use std::process::Command;
+
+/// Runs `query.exe session` as a child process and asserts that it exits
+/// with code 1.
+fn main() -> io::Result<()> {
+    let exit_code = Command::new("query.exe").arg("session").status()?.code();
+    // App returns 1 on my machine
+    assert_eq!(exit_code, Some(1));
+    Ok(())
+}
diff --git a/zluda_inject/tests/helpers/subprocess.rs b/zluda_inject/tests/helpers/subprocess.rs
new file mode 100644
index 0000000..3d1588c
--- /dev/null
+++ b/zluda_inject/tests/helpers/subprocess.rs
@@ -0,0 +1,10 @@
+#![crate_type = "bin"]
+
+use std::io;
+use std::process::Command;
+
+/// Spawns `direct_cuinit.exe` as a child process and asserts that it exits
+/// successfully.
+fn main() -> io::Result<()> {
+    let child_succeeded = Command::new("direct_cuinit.exe").status()?.success();
+    assert!(child_succeeded);
+    Ok(())
+}
diff --git a/zluda_inject/tests/inject.rs b/zluda_inject/tests/inject.rs
new file mode 100644
index 0000000..78f26f1
--- /dev/null
+++ b/zluda_inject/tests/inject.rs
@@ -0,0 +1,67 @@
+use std::{
+ env, io,
+ path::PathBuf,
+ process::{Command, Output},
+};
+
+// Each test below runs the helper executable of the same name under the
+// launcher and asserts that its stderr contains "ZLUDA_DUMP"
+// (see `run_process_and_check_for_zluda_dump`).
+#[test]
+fn direct_cuinit() -> io::Result<()> {
+ run_process_and_check_for_zluda_dump("direct_cuinit")
+}
+
+#[test]
+fn do_cuinit_early() -> io::Result<()> {
+ run_process_and_check_for_zluda_dump("do_cuinit_early")
+}
+
+#[test]
+fn do_cuinit_late() -> io::Result<()> {
+ run_process_and_check_for_zluda_dump("do_cuinit_late")
+}
+
+#[test]
+fn do_cuinit_late_clr() -> io::Result<()> {
+ run_process_and_check_for_zluda_dump("do_cuinit_late_clr")
+}
+
+#[test]
+fn indirect_cuinit() -> io::Result<()> {
+ run_process_and_check_for_zluda_dump("indirect_cuinit")
+}
+
+#[test]
+fn subprocess() -> io::Result<()> {
+ run_process_and_check_for_zluda_dump("subprocess")
+}
+
+/// Runs the `query_exe` helper under the launcher and checks that the
+/// captured stdout contains the text "SESSIONNAME".
+#[test]
+fn query_exe() -> io::Result<()> {
+    let Output { stdout, .. } = run_process("query_exe")?;
+    let captured_text = String::from_utf8(stdout).unwrap();
+    assert!(captured_text.contains("SESSIONNAME"));
+    Ok(())
+}
+
+/// Runs the helper `name` under the launcher and asserts that its captured
+/// stderr contains the text "ZLUDA_DUMP".
+///
+/// Panics when stderr is not valid UTF-8 or the marker is missing.
+fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
+    let captured_stderr = run_process(name)?.stderr;
+    let captured_text = String::from_utf8(captured_stderr).unwrap();
+    assert!(captured_text.contains("ZLUDA_DUMP"));
+    Ok(())
+}
+
+/// Runs the helper executable `<name>.exe` from `HELPERS_OUT_DIR` under the
+/// launcher, passing `zluda_dump.dll` (taken from the launcher's directory)
+/// via `--nvcuda`.
+///
+/// Returns the captured output of the launcher process.
+///
+/// # Errors
+/// Returns the I/O error if the launcher cannot be spawned.
+///
+/// # Panics
+/// Panics when the launcher exits unsuccessfully.
+fn run_process(name: &'static str) -> io::Result<Output> {
+    // Cargo exports the path of a binary target as `CARGO_BIN_EXE_<name>`;
+    // the target is declared as `zluda` in this crate's Cargo.toml.
+    let zluda_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda"));
+    let zluda_dump_dll = zluda_exe.parent().unwrap().join("zluda_dump.dll");
+    let exe_under_test = PathBuf::from(env!("HELPERS_OUT_DIR")).join(format!("{}.exe", name));
+    let test_output = Command::new(&zluda_exe)
+        .arg("--nvcuda")
+        .arg(&zluda_dump_dll)
+        .arg("--")
+        .arg(&exe_under_test)
+        .output()?;
+    assert!(test_output.status.success());
+    Ok(test_output)
+}