diff options
-rw-r--r-- | zluda_dump/src/os_win.rs | 107 | ||||
-rw-r--r-- | zluda_inject/Cargo.toml | 6 | ||||
-rw-r--r-- | zluda_inject/tests/inject.rs | 10 | ||||
-rw-r--r-- | zluda_redirect/src/lib.rs | 88 | ||||
-rw-r--r-- | zluda_redirect/src/payload_guid.rs | 1 |
5 files changed, 137 insertions, 75 deletions
diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs index ab4d1d3..2bfc457 100644 --- a/zluda_dump/src/os_win.rs +++ b/zluda_dump/src/os_win.rs @@ -1,14 +1,13 @@ use std::{
- ffi::{c_void, CStr},
+ ffi::{c_void, CStr, CString, OsString},
mem,
os::raw::c_ushort,
ptr,
};
use std::os::windows::io::AsRawHandle;
-use wchar::wch_c;
use winapi::{
- shared::minwindef::HMODULE,
+ shared::minwindef::{FARPROC, HMODULE},
um::debugapi::OutputDebugStringA,
um::libloaderapi::{GetProcAddress, LoadLibraryW},
};
@@ -17,62 +16,76 @@ use crate::cuda::CUuuid; pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll";
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
-
+const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0";
+lazy_static! {
+ static ref PLATFORM_LIBRARY: PlatformLibrary = unsafe { PlatformLibrary::new() };
+}
include!("../../zluda_redirect/src/payload_guid.rs");
-pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
- let load_lib = if is_detoured() {
- match get_non_detoured_load_library() {
- Some(load_lib) => load_lib,
- None => return ptr::null_mut(),
- }
- } else {
- LoadLibraryW
- };
- let libcuda_path_uf16 = libcuda_path
- .encode_utf16()
- .chain(std::iter::once(0))
- .collect::<Vec<_>>();
- load_lib(libcuda_path_uf16.as_ptr()) as *mut _
+#[allow(non_snake_case)]
+struct PlatformLibrary {
+ LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE,
+ GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC,
}
-unsafe fn is_detoured() -> bool {
- let mut module = ptr::null_mut();
- loop {
- module = detours_sys::DetourEnumerateModules(module);
- if module == ptr::null_mut() {
- break;
- }
- let mut size = 0;
- let payload = detours_sys::DetourFindPayload(module, &PAYLOAD_NVCUDA_GUID, &mut size);
- if payload != ptr::null_mut() {
- return true;
+impl PlatformLibrary {
+ #[allow(non_snake_case)]
+ unsafe fn new() -> Self {
+ let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() {
+ None => (
+ LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE,
+ mem::transmute(
+ GetProcAddress
+ as unsafe extern "system" fn(
+ hModule: HMODULE,
+ lpProcName: *const i8,
+ ) -> FARPROC,
+ ),
+ ),
+ Some(zluda_with) => (
+ mem::transmute(GetProcAddress(
+ zluda_with,
+ LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _,
+ )),
+ mem::transmute(GetProcAddress(
+ zluda_with,
+ GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _,
+ )),
+ ),
+ };
+ PlatformLibrary {
+ LoadLibraryW,
+ GetProcAddress,
}
}
- false
-}
-unsafe fn get_non_detoured_load_library(
-) -> Option<unsafe extern "system" fn(*const c_ushort) -> HMODULE> {
- let mut module = ptr::null_mut();
- loop {
- module = detours_sys::DetourEnumerateModules(module);
- if module == ptr::null_mut() {
- break;
- }
- let result = GetProcAddress(
- module as *mut _,
- LOAD_LIBRARY_NO_REDIRECT.as_ptr() as *mut _,
- );
- if result != ptr::null_mut() {
- return Some(mem::transmute(result));
+ unsafe fn get_detourer_module() -> Option<HMODULE> {
+ let mut module = ptr::null_mut();
+ loop {
+ module = detours_sys::DetourEnumerateModules(module);
+ if module == ptr::null_mut() {
+ break;
+ }
+ let mut size = 0;
+ let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
+ if payload != ptr::null_mut() {
+ return Some(module as _);
+ }
}
+ None
}
- None
+}
+
+pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
+ let libcuda_path_uf16 = libcuda_path
+ .encode_utf16()
+ .chain(std::iter::once(0))
+ .collect::<Vec<_>>();
+ (PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _
}
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
- GetProcAddress(handle as *mut _, func.as_ptr()) as *mut _
+ (PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _
}
#[macro_export]
diff --git a/zluda_inject/Cargo.toml b/zluda_inject/Cargo.toml index 1181a21..cc9dc9b 100644 --- a/zluda_inject/Cargo.toml +++ b/zluda_inject/Cargo.toml @@ -11,3 +11,9 @@ path = "src/main.rs" [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchapi", "winbase", "std"] } detours-sys = { path = "../detours-sys" } + +[dev-dependencies] +# dependency for integration tests +zluda_redirect = { path = "../zluda_redirect" } +# dependency for integration tests +zluda_dump = { path = "../zluda_dump" } diff --git a/zluda_inject/tests/inject.rs b/zluda_inject/tests/inject.rs index 5a19d8a..de5fef8 100644 --- a/zluda_inject/tests/inject.rs +++ b/zluda_inject/tests/inject.rs @@ -3,8 +3,8 @@ use std::{env, io, path::PathBuf, process::Command}; #[test]
fn direct_cuinit() -> io::Result<()> {
let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with"));
- let mut zluda_redirect_dll = zluda_with_exe.parent().unwrap().to_path_buf();
- zluda_redirect_dll.push("zluda_redirect.dll");
+ let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf();
+ zluda_dump_dll.push("zluda_dump.dll");
let helpers_dir = env!("HELPERS_OUT_DIR");
let exe_under_test = format!(
"{}{}direct_cuinit.exe",
@@ -12,11 +12,9 @@ fn direct_cuinit() -> io::Result<()> { std::path::MAIN_SEPARATOR
);
let mut test_cmd = Command::new(&zluda_with_exe);
- test_cmd
- .arg(&zluda_redirect_dll)
- .arg("--")
- .arg(&exe_under_test);
+ let test_cmd = test_cmd.arg(&zluda_dump_dll).arg("--").arg(&exe_under_test);
let test_output = test_cmd.output()?;
+ assert!(test_output.status.success());
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
assert!(stderr_text.contains("ZLUDA_DUMP"));
Ok(())
diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs index d695ff7..f2d2739 100644 --- a/zluda_redirect/src/lib.rs +++ b/zluda_redirect/src/lib.rs @@ -4,6 +4,7 @@ extern crate detours_sys; extern crate winapi; use std::{ + collections::HashMap, ffi::{c_void, CStr}, mem, os::raw::c_uint, @@ -61,9 +62,13 @@ static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None; static mut ZLUDA_ML_PATH_UTF8: Vec<u8> = Vec::new(); static mut ZLUDA_ML_PATH_UTF16: Option<&'static [u16]> = None; static mut CURRENT_MODULE_FILENAME: Vec<u8> = Vec::new(); -static mut DETOUR_DETACH: Option<DetourDetachGuard> = None; +static mut DETOUR_STATE: Option<DetourDetachGuard> = None; const CUDA_ERROR_NOT_SUPPORTED: c_uint = 801; +#[no_mangle] +#[used] +pub static ZLUDA_REDIRECT: () = (); + static mut LOAD_LIBRARY_A: unsafe extern "system" fn(lpLibFileName: LPCSTR) -> HMODULE = LoadLibraryA; @@ -150,6 +155,24 @@ static mut CREATE_PROCESS_WITH_LOGON_W: unsafe extern "system" fn( #[no_mangle] #[allow(non_snake_case)] +unsafe extern "system" fn ZludaGetProcAddress_NoRedirect( + hModule: HMODULE, + lpProcName: LPCSTR, +) -> FARPROC { + if let Some(detour_guard) = &DETOUR_STATE { + if hModule != ptr::null_mut() && detour_guard.nvcuda_module == hModule { + let proc_name = CStr::from_ptr(lpProcName); + return match detour_guard.overriden_cuda_fns.get(proc_name) { + Some((original_fn, _)) => mem::transmute::<*mut c_void, _>(*original_fn), + None => ptr::null_mut(), + }; + } + } + GetProcAddress(hModule, lpProcName) +} + +#[no_mangle] +#[allow(non_snake_case)] unsafe extern "system" fn ZludaLoadLibraryW_NoRedirect(lpLibFileName: LPCWSTR) -> HMODULE { (LOAD_LIBRARY_W)(lpLibFileName) } @@ -361,7 +384,9 @@ struct DetourDetachGuard { state: DetourUndoState, suspended_threads: Vec<*mut c_void>, // First element is the original fn, second is the new fn - overriden_functions: Vec<(*mut c_void, *mut c_void)>, + overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>, + nvcuda_module: HMODULE, + overriden_cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>, } impl DetourDetachGuard { @@ -371,12 +396,16 @@ impl DetourDetachGuard { // also get overriden, so for example ZludaLoadLibraryExW ends calling // itself recursively until stack overflow exception occurs unsafe fn detour_functions<'a>( - override_fn_pairs: Vec<(*mut c_void, *mut c_void)>, + nvcuda_module: HMODULE, + non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>, + cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>, ) -> Option<Self> { let mut result = DetourDetachGuard { state: DetourUndoState::DoNothing, suspended_threads: Vec::new(), - overriden_functions: override_fn_pairs, + overriden_non_cuda_fns: non_cuda_fns, + nvcuda_module, + overriden_cuda_fns: cuda_fns, }; if DetourTransactionBegin() != NO_ERROR as i32 { return None; @@ -390,24 +419,35 @@ impl DetourDetachGuard { return None; } } - result.overriden_functions.extend_from_slice(&[ - (CREATE_PROCESS_A as _, ZludaCreateProcessA as _), - (CREATE_PROCESS_W as _, ZludaCreateProcessW as _), + result.overriden_non_cuda_fns.extend_from_slice(&[ ( - CREATE_PROCESS_AS_USER_W as _, + &mut CREATE_PROCESS_A as *mut _ as _, + ZludaCreateProcessA as _, + ), + ( + &mut CREATE_PROCESS_W as *mut _ as _, + ZludaCreateProcessW as _, + ), + ( + &mut CREATE_PROCESS_AS_USER_W as *mut _ as _, ZludaCreateProcessAsUserW as _, ), ( - CREATE_PROCESS_WITH_LOGON_W as _, + &mut CREATE_PROCESS_WITH_LOGON_W as *mut _ as _, ZludaCreateProcessWithLogonW as _, ), ( - CREATE_PROCESS_WITH_TOKEN_W as _, + &mut CREATE_PROCESS_WITH_TOKEN_W as *mut _ as _, ZludaCreateProcessWithTokenW as _, ), ]); - for (original_fn, new_fn) in result.overriden_functions.iter_mut() { - if DetourAttach(original_fn as *mut _, *new_fn) != NO_ERROR as i32 { + for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied().chain( + result + .overriden_cuda_fns + .values_mut() + .map(|(original_ptr, new_ptr)| (original_ptr as *mut _, *new_ptr)), + ) { + if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 { return None; } } @@ -633,13 +673,13 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u }; match detach_guard { Some(g) => { - DETOUR_DETACH = Some(g); + DETOUR_STATE = Some(g); TRUE } None => FALSE, } } else if dwReason == DLL_PROCESS_DETACH { - match DETOUR_DETACH.take() { + match DETOUR_STATE.take() { Some(_) => TRUE, None => FALSE, } @@ -691,9 +731,9 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> { } let original_functions = gather_imports(nvcuda_mod); let override_functions = gather_imports(zluda_module); - let mut override_fn_pairs = Vec::with_capacity(original_functions.len()); + let mut override_fn_pairs = HashMap::with_capacity(original_functions.len()); // TODO: optimize - for (original_fn_name, mut original_fn_address) in original_functions { + for (original_fn_name, original_fn_address) in original_functions { let override_fn_address = match override_functions.binary_search_by_key(&original_fn_name, |(name, _)| *name) { Ok(x) => override_functions[x].1, @@ -702,9 +742,12 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> { cuda_unsupported as _ } }; - override_fn_pairs.push((original_fn_address as _, override_fn_address)); + override_fn_pairs.insert( + original_fn_name, + (original_fn_address as _, override_fn_address), + ); } - DetourDetachGuard::detour_functions(override_fn_pairs) + DetourDetachGuard::detour_functions(nvcuda_mod, Vec::new(), override_fn_pairs) } unsafe extern "system" fn cuda_unsupported() -> c_uint { @@ -735,7 +778,10 @@ unsafe extern "stdcall" fn gather_imports_impl( #[must_use] unsafe fn attach_load_libary() -> Option<DetourDetachGuard> { let detour_functions = vec![ - (&mut LOAD_LIBRARY_A as *mut _ as _, ZludaLoadLibraryA as _), + ( + &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void, + ZludaLoadLibraryA as *mut c_void, + ), (&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _), ( &mut LOAD_LIBRARY_EX_A as *mut _ as _, @@ -746,9 +792,7 @@ unsafe fn attach_load_libary() -> Option<DetourDetachGuard> { ZludaLoadLibraryExW as _, ), ]; - let result = DetourDetachGuard::detour_functions(detour_functions); - - result + DetourDetachGuard::detour_functions(ptr::null_mut(), detour_functions, HashMap::new()) } fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> { diff --git a/zluda_redirect/src/payload_guid.rs b/zluda_redirect/src/payload_guid.rs index 2d7ee6c..968e244 100644 --- a/zluda_redirect/src/payload_guid.rs +++ b/zluda_redirect/src/payload_guid.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)]
const PAYLOAD_NVCUDA_GUID: detours_sys::GUID = detours_sys::GUID {
Data1: 0xC225FC0C,
Data2: 0x00D7,
|