aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-12-01 23:08:07 +0100
committerAndrzej Janik <[email protected]>2021-12-01 23:08:07 +0100
commit400feaf015fc0084608479df58d6cdeccb87986b (patch)
tree3434eb1974bb6135b8e58e232235a92163477bcd
parentfd1c13560f29e9f6e43d19b5cbe48dcd1351bcd6 (diff)
downloadZLUDA-400feaf015fc0084608479df58d6cdeccb87986b.tar.gz
ZLUDA-400feaf015fc0084608479df58d6cdeccb87986b.zip
Add test for injecting app that directly uses nvcuda
-rw-r--r--zluda_dump/src/os_win.rs107
-rw-r--r--zluda_inject/Cargo.toml6
-rw-r--r--zluda_inject/tests/inject.rs10
-rw-r--r--zluda_redirect/src/lib.rs88
-rw-r--r--zluda_redirect/src/payload_guid.rs1
5 files changed, 137 insertions, 75 deletions
diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs
index ab4d1d3..2bfc457 100644
--- a/zluda_dump/src/os_win.rs
+++ b/zluda_dump/src/os_win.rs
@@ -1,14 +1,13 @@
use std::{
- ffi::{c_void, CStr},
+ ffi::{c_void, CStr, CString, OsString},
mem,
os::raw::c_ushort,
ptr,
};
use std::os::windows::io::AsRawHandle;
-use wchar::wch_c;
use winapi::{
- shared::minwindef::HMODULE,
+ shared::minwindef::{FARPROC, HMODULE},
um::debugapi::OutputDebugStringA,
um::libloaderapi::{GetProcAddress, LoadLibraryW},
};
@@ -17,62 +16,76 @@ use crate::cuda::CUuuid;
pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll";
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
-
+const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0";
+lazy_static! {
+ static ref PLATFORM_LIBRARY: PlatformLibrary = unsafe { PlatformLibrary::new() };
+}
include!("../../zluda_redirect/src/payload_guid.rs");
-pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
- let load_lib = if is_detoured() {
- match get_non_detoured_load_library() {
- Some(load_lib) => load_lib,
- None => return ptr::null_mut(),
- }
- } else {
- LoadLibraryW
- };
- let libcuda_path_uf16 = libcuda_path
- .encode_utf16()
- .chain(std::iter::once(0))
- .collect::<Vec<_>>();
- load_lib(libcuda_path_uf16.as_ptr()) as *mut _
+#[allow(non_snake_case)]
+struct PlatformLibrary {
+ LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE,
+ GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC,
}
-unsafe fn is_detoured() -> bool {
- let mut module = ptr::null_mut();
- loop {
- module = detours_sys::DetourEnumerateModules(module);
- if module == ptr::null_mut() {
- break;
- }
- let mut size = 0;
- let payload = detours_sys::DetourFindPayload(module, &PAYLOAD_NVCUDA_GUID, &mut size);
- if payload != ptr::null_mut() {
- return true;
+impl PlatformLibrary {
+ #[allow(non_snake_case)]
+ unsafe fn new() -> Self {
+ let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() {
+ None => (
+ LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE,
+ mem::transmute(
+ GetProcAddress
+ as unsafe extern "system" fn(
+ hModule: HMODULE,
+ lpProcName: *const i8,
+ ) -> FARPROC,
+ ),
+ ),
+ Some(zluda_with) => (
+ mem::transmute(GetProcAddress(
+ zluda_with,
+ LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _,
+ )),
+ mem::transmute(GetProcAddress(
+ zluda_with,
+ GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _,
+ )),
+ ),
+ };
+ PlatformLibrary {
+ LoadLibraryW,
+ GetProcAddress,
}
}
- false
-}
-unsafe fn get_non_detoured_load_library(
-) -> Option<unsafe extern "system" fn(*const c_ushort) -> HMODULE> {
- let mut module = ptr::null_mut();
- loop {
- module = detours_sys::DetourEnumerateModules(module);
- if module == ptr::null_mut() {
- break;
- }
- let result = GetProcAddress(
- module as *mut _,
- LOAD_LIBRARY_NO_REDIRECT.as_ptr() as *mut _,
- );
- if result != ptr::null_mut() {
- return Some(mem::transmute(result));
+ unsafe fn get_detourer_module() -> Option<HMODULE> {
+ let mut module = ptr::null_mut();
+ loop {
+ module = detours_sys::DetourEnumerateModules(module);
+ if module == ptr::null_mut() {
+ break;
+ }
+ let mut size = 0;
+ let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
+ if payload != ptr::null_mut() {
+ return Some(module as _);
+ }
}
+ None
}
- None
+}
+
+pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
+ let libcuda_path_uf16 = libcuda_path
+ .encode_utf16()
+ .chain(std::iter::once(0))
+ .collect::<Vec<_>>();
+ (PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _
}
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
- GetProcAddress(handle as *mut _, func.as_ptr()) as *mut _
+ (PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _
}
#[macro_export]
diff --git a/zluda_inject/Cargo.toml b/zluda_inject/Cargo.toml
index 1181a21..cc9dc9b 100644
--- a/zluda_inject/Cargo.toml
+++ b/zluda_inject/Cargo.toml
@@ -11,3 +11,9 @@ path = "src/main.rs"
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchapi", "winbase", "std"] }
detours-sys = { path = "../detours-sys" }
+
+[dev-dependencies]
+# dependency for integration tests
+zluda_redirect = { path = "../zluda_redirect" }
+# dependency for integration tests
+zluda_dump = { path = "../zluda_dump" }
diff --git a/zluda_inject/tests/inject.rs b/zluda_inject/tests/inject.rs
index 5a19d8a..de5fef8 100644
--- a/zluda_inject/tests/inject.rs
+++ b/zluda_inject/tests/inject.rs
@@ -3,8 +3,8 @@ use std::{env, io, path::PathBuf, process::Command};
#[test]
fn direct_cuinit() -> io::Result<()> {
let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with"));
- let mut zluda_redirect_dll = zluda_with_exe.parent().unwrap().to_path_buf();
- zluda_redirect_dll.push("zluda_redirect.dll");
+ let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf();
+ zluda_dump_dll.push("zluda_dump.dll");
let helpers_dir = env!("HELPERS_OUT_DIR");
let exe_under_test = format!(
"{}{}direct_cuinit.exe",
@@ -12,11 +12,9 @@ fn direct_cuinit() -> io::Result<()> {
std::path::MAIN_SEPARATOR
);
let mut test_cmd = Command::new(&zluda_with_exe);
- test_cmd
- .arg(&zluda_redirect_dll)
- .arg("--")
- .arg(&exe_under_test);
+ let test_cmd = test_cmd.arg(&zluda_dump_dll).arg("--").arg(&exe_under_test);
let test_output = test_cmd.output()?;
+ assert!(test_output.status.success());
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
assert!(stderr_text.contains("ZLUDA_DUMP"));
Ok(())
diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs
index d695ff7..f2d2739 100644
--- a/zluda_redirect/src/lib.rs
+++ b/zluda_redirect/src/lib.rs
@@ -4,6 +4,7 @@ extern crate detours_sys;
extern crate winapi;
use std::{
+ collections::HashMap,
ffi::{c_void, CStr},
mem,
os::raw::c_uint,
@@ -61,9 +62,13 @@ static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None;
static mut ZLUDA_ML_PATH_UTF8: Vec<u8> = Vec::new();
static mut ZLUDA_ML_PATH_UTF16: Option<&'static [u16]> = None;
static mut CURRENT_MODULE_FILENAME: Vec<u8> = Vec::new();
-static mut DETOUR_DETACH: Option<DetourDetachGuard> = None;
+static mut DETOUR_STATE: Option<DetourDetachGuard> = None;
const CUDA_ERROR_NOT_SUPPORTED: c_uint = 801;
+#[no_mangle]
+#[used]
+pub static ZLUDA_REDIRECT: () = ();
+
static mut LOAD_LIBRARY_A: unsafe extern "system" fn(lpLibFileName: LPCSTR) -> HMODULE =
LoadLibraryA;
@@ -150,6 +155,24 @@ static mut CREATE_PROCESS_WITH_LOGON_W: unsafe extern "system" fn(
#[no_mangle]
#[allow(non_snake_case)]
+unsafe extern "system" fn ZludaGetProcAddress_NoRedirect(
+ hModule: HMODULE,
+ lpProcName: LPCSTR,
+) -> FARPROC {
+ if let Some(detour_guard) = &DETOUR_STATE {
+ if hModule != ptr::null_mut() && detour_guard.nvcuda_module == hModule {
+ let proc_name = CStr::from_ptr(lpProcName);
+ return match detour_guard.overriden_cuda_fns.get(proc_name) {
+ Some((original_fn, _)) => mem::transmute::<*mut c_void, _>(*original_fn),
+ None => ptr::null_mut(),
+ };
+ }
+ }
+ GetProcAddress(hModule, lpProcName)
+}
+
+#[no_mangle]
+#[allow(non_snake_case)]
unsafe extern "system" fn ZludaLoadLibraryW_NoRedirect(lpLibFileName: LPCWSTR) -> HMODULE {
(LOAD_LIBRARY_W)(lpLibFileName)
}
@@ -361,7 +384,9 @@ struct DetourDetachGuard {
state: DetourUndoState,
suspended_threads: Vec<*mut c_void>,
// First element is the original fn, second is the new fn
- overriden_functions: Vec<(*mut c_void, *mut c_void)>,
+ overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
+ nvcuda_module: HMODULE,
+ overriden_cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
}
impl DetourDetachGuard {
@@ -371,12 +396,16 @@ impl DetourDetachGuard {
// also get overriden, so for example ZludaLoadLibraryExW ends calling
// itself recursively until stack overflow exception occurs
unsafe fn detour_functions<'a>(
- override_fn_pairs: Vec<(*mut c_void, *mut c_void)>,
+ nvcuda_module: HMODULE,
+ non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
+ cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
) -> Option<Self> {
let mut result = DetourDetachGuard {
state: DetourUndoState::DoNothing,
suspended_threads: Vec::new(),
- overriden_functions: override_fn_pairs,
+ overriden_non_cuda_fns: non_cuda_fns,
+ nvcuda_module,
+ overriden_cuda_fns: cuda_fns,
};
if DetourTransactionBegin() != NO_ERROR as i32 {
return None;
@@ -390,24 +419,35 @@ impl DetourDetachGuard {
return None;
}
}
- result.overriden_functions.extend_from_slice(&[
- (CREATE_PROCESS_A as _, ZludaCreateProcessA as _),
- (CREATE_PROCESS_W as _, ZludaCreateProcessW as _),
+ result.overriden_non_cuda_fns.extend_from_slice(&[
(
- CREATE_PROCESS_AS_USER_W as _,
+ &mut CREATE_PROCESS_A as *mut _ as _,
+ ZludaCreateProcessA as _,
+ ),
+ (
+ &mut CREATE_PROCESS_W as *mut _ as _,
+ ZludaCreateProcessW as _,
+ ),
+ (
+ &mut CREATE_PROCESS_AS_USER_W as *mut _ as _,
ZludaCreateProcessAsUserW as _,
),
(
- CREATE_PROCESS_WITH_LOGON_W as _,
+ &mut CREATE_PROCESS_WITH_LOGON_W as *mut _ as _,
ZludaCreateProcessWithLogonW as _,
),
(
- CREATE_PROCESS_WITH_TOKEN_W as _,
+ &mut CREATE_PROCESS_WITH_TOKEN_W as *mut _ as _,
ZludaCreateProcessWithTokenW as _,
),
]);
- for (original_fn, new_fn) in result.overriden_functions.iter_mut() {
- if DetourAttach(original_fn as *mut _, *new_fn) != NO_ERROR as i32 {
+ for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied().chain(
+ result
+ .overriden_cuda_fns
+ .values_mut()
+ .map(|(original_ptr, new_ptr)| (original_ptr as *mut _, *new_ptr)),
+ ) {
+ if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 {
return None;
}
}
@@ -633,13 +673,13 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
};
match detach_guard {
Some(g) => {
- DETOUR_DETACH = Some(g);
+ DETOUR_STATE = Some(g);
TRUE
}
None => FALSE,
}
} else if dwReason == DLL_PROCESS_DETACH {
- match DETOUR_DETACH.take() {
+ match DETOUR_STATE.take() {
Some(_) => TRUE,
None => FALSE,
}
@@ -691,9 +731,9 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
}
let original_functions = gather_imports(nvcuda_mod);
let override_functions = gather_imports(zluda_module);
- let mut override_fn_pairs = Vec::with_capacity(original_functions.len());
+ let mut override_fn_pairs = HashMap::with_capacity(original_functions.len());
// TODO: optimize
- for (original_fn_name, mut original_fn_address) in original_functions {
+ for (original_fn_name, original_fn_address) in original_functions {
let override_fn_address =
match override_functions.binary_search_by_key(&original_fn_name, |(name, _)| *name) {
Ok(x) => override_functions[x].1,
@@ -702,9 +742,12 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
cuda_unsupported as _
}
};
- override_fn_pairs.push((original_fn_address as _, override_fn_address));
+ override_fn_pairs.insert(
+ original_fn_name,
+ (original_fn_address as _, override_fn_address),
+ );
}
- DetourDetachGuard::detour_functions(override_fn_pairs)
+ DetourDetachGuard::detour_functions(nvcuda_mod, Vec::new(), override_fn_pairs)
}
unsafe extern "system" fn cuda_unsupported() -> c_uint {
@@ -735,7 +778,10 @@ unsafe extern "stdcall" fn gather_imports_impl(
#[must_use]
unsafe fn attach_load_libary() -> Option<DetourDetachGuard> {
let detour_functions = vec![
- (&mut LOAD_LIBRARY_A as *mut _ as _, ZludaLoadLibraryA as _),
+ (
+ &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
+ ZludaLoadLibraryA as *mut c_void,
+ ),
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
(
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
@@ -746,9 +792,7 @@ unsafe fn attach_load_libary() -> Option<DetourDetachGuard> {
ZludaLoadLibraryExW as _,
),
];
- let result = DetourDetachGuard::detour_functions(detour_functions);
-
- result
+ DetourDetachGuard::detour_functions(ptr::null_mut(), detour_functions, HashMap::new())
}
fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> {
diff --git a/zluda_redirect/src/payload_guid.rs b/zluda_redirect/src/payload_guid.rs
index 2d7ee6c..968e244 100644
--- a/zluda_redirect/src/payload_guid.rs
+++ b/zluda_redirect/src/payload_guid.rs
@@ -1,3 +1,4 @@
+#[allow(dead_code)]
const PAYLOAD_NVCUDA_GUID: detours_sys::GUID = detours_sys::GUID {
Data1: 0xC225FC0C,
Data2: 0x00D7,