From 400feaf015fc0084608479df58d6cdeccb87986b Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 1 Dec 2021 23:08:07 +0100 Subject: Add test for injecting app that directly uses nvcuda --- zluda_dump/src/os_win.rs | 107 +++++++++++++++++++++---------------- zluda_inject/Cargo.toml | 6 +++ zluda_inject/tests/inject.rs | 10 ++-- zluda_redirect/src/lib.rs | 88 ++++++++++++++++++++++-------- zluda_redirect/src/payload_guid.rs | 1 + 5 files changed, 137 insertions(+), 75 deletions(-) diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs index ab4d1d3..2bfc457 100644 --- a/zluda_dump/src/os_win.rs +++ b/zluda_dump/src/os_win.rs @@ -1,14 +1,13 @@ use std::{ - ffi::{c_void, CStr}, + ffi::{c_void, CStr, CString, OsString}, mem, os::raw::c_ushort, ptr, }; use std::os::windows::io::AsRawHandle; -use wchar::wch_c; use winapi::{ - shared::minwindef::HMODULE, + shared::minwindef::{FARPROC, HMODULE}, um::debugapi::OutputDebugStringA, um::libloaderapi::{GetProcAddress, LoadLibraryW}, }; @@ -17,62 +16,76 @@ use crate::cuda::CUuuid; pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll"; const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0"; - +const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0"; +lazy_static! { + static ref PLATFORM_LIBRARY: PlatformLibrary = unsafe { PlatformLibrary::new() }; +} include!("../../zluda_redirect/src/payload_guid.rs"); -pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void { - let load_lib = if is_detoured() { - match get_non_detoured_load_library() { - Some(load_lib) => load_lib, - None => return ptr::null_mut(), - } - } else { - LoadLibraryW - }; - let libcuda_path_uf16 = libcuda_path - .encode_utf16() - .chain(std::iter::once(0)) - .collect::>(); - load_lib(libcuda_path_uf16.as_ptr()) as *mut _ +#[allow(non_snake_case)] +struct PlatformLibrary { + LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE, + GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC, } -unsafe fn is_detoured() -> bool { - let mut module = ptr::null_mut(); - loop { - module = detours_sys::DetourEnumerateModules(module); - if module == ptr::null_mut() { - break; - } - let mut size = 0; - let payload = detours_sys::DetourFindPayload(module, &PAYLOAD_NVCUDA_GUID, &mut size); - if payload != ptr::null_mut() { - return true; +impl PlatformLibrary { + #[allow(non_snake_case)] + unsafe fn new() -> Self { + let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() { + None => ( + LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE, + mem::transmute( + GetProcAddress + as unsafe extern "system" fn( + hModule: HMODULE, + lpProcName: *const i8, + ) -> FARPROC, + ), + ), + Some(zluda_with) => ( + mem::transmute(GetProcAddress( + zluda_with, + LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _, + )), + mem::transmute(GetProcAddress( + zluda_with, + GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _, + )), + ), + }; + PlatformLibrary { + LoadLibraryW, + GetProcAddress, } } - false -} -unsafe fn get_non_detoured_load_library( -) -> Option HMODULE> { - let mut module = ptr::null_mut(); - loop { - module = detours_sys::DetourEnumerateModules(module); - if module == ptr::null_mut() { - break; - } - let result = GetProcAddress( - module as *mut _, - LOAD_LIBRARY_NO_REDIRECT.as_ptr() as *mut _, - ); - if result != ptr::null_mut() { - return Some(mem::transmute(result)); + unsafe fn get_detourer_module() -> Option { + let mut module = ptr::null_mut(); + loop { + module = detours_sys::DetourEnumerateModules(module); + if module == ptr::null_mut() { + break; + } + let mut size = 0; + let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _); + if payload != ptr::null_mut() { + return Some(module as _); + } } + None } - None +} + +pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void { + let libcuda_path_uf16 = libcuda_path + .encode_utf16() + .chain(std::iter::once(0)) + .collect::>(); + (PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _ } pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void { - GetProcAddress(handle as *mut _, func.as_ptr()) as *mut _ + (PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _ } #[macro_export] diff --git a/zluda_inject/Cargo.toml b/zluda_inject/Cargo.toml index 1181a21..cc9dc9b 100644 --- a/zluda_inject/Cargo.toml +++ b/zluda_inject/Cargo.toml @@ -11,3 +11,9 @@ path = "src/main.rs" [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchapi", "winbase", "std"] } detours-sys = { path = "../detours-sys" } + +[dev-dependencies] +# dependency for integration tests +zluda_redirect = { path = "../zluda_redirect" } +# dependency for integration tests +zluda_dump = { path = "../zluda_dump" } diff --git a/zluda_inject/tests/inject.rs b/zluda_inject/tests/inject.rs index 5a19d8a..de5fef8 100644 --- a/zluda_inject/tests/inject.rs +++ b/zluda_inject/tests/inject.rs @@ -3,8 +3,8 @@ use std::{env, io, path::PathBuf, process::Command}; #[test] fn direct_cuinit() -> io::Result<()> { let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with")); - let mut zluda_redirect_dll = zluda_with_exe.parent().unwrap().to_path_buf(); - zluda_redirect_dll.push("zluda_redirect.dll"); + let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf(); + zluda_dump_dll.push("zluda_dump.dll"); let helpers_dir = env!("HELPERS_OUT_DIR"); let exe_under_test = format!( "{}{}direct_cuinit.exe", @@ -12,11 +12,9 @@ fn direct_cuinit() -> io::Result<()> { std::path::MAIN_SEPARATOR ); let mut test_cmd = Command::new(&zluda_with_exe); - test_cmd - .arg(&zluda_redirect_dll) - .arg("--") - .arg(&exe_under_test); + let test_cmd = test_cmd.arg(&zluda_dump_dll).arg("--").arg(&exe_under_test); let test_output = test_cmd.output()?; + assert!(test_output.status.success()); let stderr_text = String::from_utf8(test_output.stderr).unwrap(); assert!(stderr_text.contains("ZLUDA_DUMP")); Ok(()) diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs index d695ff7..f2d2739 100644 --- a/zluda_redirect/src/lib.rs +++ b/zluda_redirect/src/lib.rs @@ -4,6 +4,7 @@ extern crate detours_sys; extern crate winapi; use std::{ + collections::HashMap, ffi::{c_void, CStr}, mem, os::raw::c_uint, @@ -61,9 +62,13 @@ static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None; static mut ZLUDA_ML_PATH_UTF8: Vec = Vec::new(); static mut ZLUDA_ML_PATH_UTF16: Option<&'static [u16]> = None; static mut CURRENT_MODULE_FILENAME: Vec = Vec::new(); -static mut DETOUR_DETACH: Option = None; +static mut DETOUR_STATE: Option = None; const CUDA_ERROR_NOT_SUPPORTED: c_uint = 801; +#[no_mangle] +#[used] +pub static ZLUDA_REDIRECT: () = (); + static mut LOAD_LIBRARY_A: unsafe extern "system" fn(lpLibFileName: LPCSTR) -> HMODULE = LoadLibraryA; @@ -148,6 +153,24 @@ static mut CREATE_PROCESS_WITH_LOGON_W: unsafe extern "system" fn( lpProcessInformation: LPPROCESS_INFORMATION, ) -> BOOL = CreateProcessWithLogonW; +#[no_mangle] +#[allow(non_snake_case)] +unsafe extern "system" fn ZludaGetProcAddress_NoRedirect( + hModule: HMODULE, + lpProcName: LPCSTR, +) -> FARPROC { + if let Some(detour_guard) = &DETOUR_STATE { + if hModule != ptr::null_mut() && detour_guard.nvcuda_module == hModule { + let proc_name = CStr::from_ptr(lpProcName); + return match detour_guard.overriden_cuda_fns.get(proc_name) { + Some((original_fn, _)) => mem::transmute::<*mut c_void, _>(*original_fn), + None => ptr::null_mut(), + }; + } + } + GetProcAddress(hModule, lpProcName) +} + #[no_mangle] #[allow(non_snake_case)] unsafe extern "system" fn ZludaLoadLibraryW_NoRedirect(lpLibFileName: LPCWSTR) -> HMODULE { @@ -361,7 +384,9 @@ struct DetourDetachGuard { state: DetourUndoState, suspended_threads: Vec<*mut c_void>, // First element is the original fn, second is the new fn - overriden_functions: Vec<(*mut c_void, *mut c_void)>, + overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>, + nvcuda_module: HMODULE, + overriden_cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>, } impl DetourDetachGuard { @@ -371,12 +396,16 @@ impl DetourDetachGuard { // also get overriden, so for example ZludaLoadLibraryExW ends calling // itself recursively until stack overflow exception occurs unsafe fn detour_functions<'a>( - override_fn_pairs: Vec<(*mut c_void, *mut c_void)>, + nvcuda_module: HMODULE, + non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>, + cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>, ) -> Option { let mut result = DetourDetachGuard { state: DetourUndoState::DoNothing, suspended_threads: Vec::new(), - overriden_functions: override_fn_pairs, + overriden_non_cuda_fns: non_cuda_fns, + nvcuda_module, + overriden_cuda_fns: cuda_fns, }; if DetourTransactionBegin() != NO_ERROR as i32 { return None; @@ -390,24 +419,35 @@ impl DetourDetachGuard { return None; } } - result.overriden_functions.extend_from_slice(&[ - (CREATE_PROCESS_A as _, ZludaCreateProcessA as _), - (CREATE_PROCESS_W as _, ZludaCreateProcessW as _), + result.overriden_non_cuda_fns.extend_from_slice(&[ ( - CREATE_PROCESS_AS_USER_W as _, + &mut CREATE_PROCESS_A as *mut _ as _, + ZludaCreateProcessA as _, + ), + ( + &mut CREATE_PROCESS_W as *mut _ as _, + ZludaCreateProcessW as _, + ), + ( + &mut CREATE_PROCESS_AS_USER_W as *mut _ as _, ZludaCreateProcessAsUserW as _, ), ( - CREATE_PROCESS_WITH_LOGON_W as _, + &mut CREATE_PROCESS_WITH_LOGON_W as *mut _ as _, ZludaCreateProcessWithLogonW as _, ), ( - CREATE_PROCESS_WITH_TOKEN_W as _, + &mut CREATE_PROCESS_WITH_TOKEN_W as *mut _ as _, ZludaCreateProcessWithTokenW as _, ), ]); - for (original_fn, new_fn) in result.overriden_functions.iter_mut() { - if DetourAttach(original_fn as *mut _, *new_fn) != NO_ERROR as i32 { + for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied().chain( + result + .overriden_cuda_fns + .values_mut() + .map(|(original_ptr, new_ptr)| (original_ptr as *mut _, *new_ptr)), + ) { + if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 { return None; } } @@ -633,13 +673,13 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u }; match detach_guard { Some(g) => { - DETOUR_DETACH = Some(g); + DETOUR_STATE = Some(g); TRUE } None => FALSE, } } else if dwReason == DLL_PROCESS_DETACH { - match DETOUR_DETACH.take() { + match DETOUR_STATE.take() { Some(_) => TRUE, None => FALSE, } @@ -691,9 +731,9 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option { } let original_functions = gather_imports(nvcuda_mod); let override_functions = gather_imports(zluda_module); - let mut override_fn_pairs = Vec::with_capacity(original_functions.len()); + let mut override_fn_pairs = HashMap::with_capacity(original_functions.len()); // TODO: optimize - for (original_fn_name, mut original_fn_address) in original_functions { + for (original_fn_name, original_fn_address) in original_functions { let override_fn_address = match override_functions.binary_search_by_key(&original_fn_name, |(name, _)| *name) { Ok(x) => override_functions[x].1, @@ -702,9 +742,12 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option { cuda_unsupported as _ } }; - override_fn_pairs.push((original_fn_address as _, override_fn_address)); + override_fn_pairs.insert( + original_fn_name, + (original_fn_address as _, override_fn_address), + ); } - DetourDetachGuard::detour_functions(override_fn_pairs) + DetourDetachGuard::detour_functions(nvcuda_mod, Vec::new(), override_fn_pairs) } unsafe extern "system" fn cuda_unsupported() -> c_uint { @@ -735,7 +778,10 @@ unsafe extern "stdcall" fn gather_imports_impl( #[must_use] unsafe fn attach_load_libary() -> Option { let detour_functions = vec![ - (&mut LOAD_LIBRARY_A as *mut _ as _, ZludaLoadLibraryA as _), + ( + &mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void, + ZludaLoadLibraryA as *mut c_void, + ), (&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _), ( &mut LOAD_LIBRARY_EX_A as *mut _ as _, @@ -746,9 +792,7 @@ unsafe fn attach_load_libary() -> Option { ZludaLoadLibraryExW as _, ), ]; - let result = DetourDetachGuard::detour_functions(detour_functions); - - result + DetourDetachGuard::detour_functions(ptr::null_mut(), detour_functions, HashMap::new()) } fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> { diff --git a/zluda_redirect/src/payload_guid.rs b/zluda_redirect/src/payload_guid.rs index 2d7ee6c..968e244 100644 --- a/zluda_redirect/src/payload_guid.rs +++ b/zluda_redirect/src/payload_guid.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] const PAYLOAD_NVCUDA_GUID: detours_sys::GUID = detours_sys::GUID { Data1: 0xC225FC0C, Data2: 0x00D7, -- cgit v1.2.3