diff options
author | Andrzej Janik <[email protected]> | 2021-12-01 23:08:07 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2021-12-01 23:08:07 +0100 |
commit | 400feaf015fc0084608479df58d6cdeccb87986b (patch) | |
tree | 3434eb1974bb6135b8e58e232235a92163477bcd /zluda_dump/src/os_win.rs | |
parent | fd1c13560f29e9f6e43d19b5cbe48dcd1351bcd6 (diff) | |
download | ZLUDA-400feaf015fc0084608479df58d6cdeccb87986b.tar.gz ZLUDA-400feaf015fc0084608479df58d6cdeccb87986b.zip |
Add test for injecting app that directly uses nvcuda
Diffstat (limited to 'zluda_dump/src/os_win.rs')
-rw-r--r-- | zluda_dump/src/os_win.rs | 107 |
1 files changed, 60 insertions, 47 deletions
diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs index ab4d1d3..2bfc457 100644 --- a/zluda_dump/src/os_win.rs +++ b/zluda_dump/src/os_win.rs @@ -1,14 +1,13 @@ use std::{
- ffi::{c_void, CStr},
+ ffi::{c_void, CStr, CString, OsString},
mem,
os::raw::c_ushort,
ptr,
};
use std::os::windows::io::AsRawHandle;
-use wchar::wch_c;
use winapi::{
- shared::minwindef::HMODULE,
+ shared::minwindef::{FARPROC, HMODULE},
um::debugapi::OutputDebugStringA,
um::libloaderapi::{GetProcAddress, LoadLibraryW},
};
@@ -17,62 +16,76 @@ use crate::cuda::CUuuid; pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = "C:\\Windows\\System32\\nvcuda.dll";
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
-
+const GET_PROC_ADDRESS_NO_REDIRECT: &'static [u8] = b"ZludaGetProcAddress_NoRedirect\0";
+lazy_static! {
+ static ref PLATFORM_LIBRARY: PlatformLibrary = unsafe { PlatformLibrary::new() };
+}
include!("../../zluda_redirect/src/payload_guid.rs");
-pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
- let load_lib = if is_detoured() {
- match get_non_detoured_load_library() {
- Some(load_lib) => load_lib,
- None => return ptr::null_mut(),
- }
- } else {
- LoadLibraryW
- };
- let libcuda_path_uf16 = libcuda_path
- .encode_utf16()
- .chain(std::iter::once(0))
- .collect::<Vec<_>>();
- load_lib(libcuda_path_uf16.as_ptr()) as *mut _
+#[allow(non_snake_case)]
+struct PlatformLibrary {
+ LoadLibraryW: unsafe extern "system" fn(*const u16) -> HMODULE,
+ GetProcAddress: unsafe extern "system" fn(hModule: HMODULE, lpProcName: *const u8) -> FARPROC,
}
-unsafe fn is_detoured() -> bool {
- let mut module = ptr::null_mut();
- loop {
- module = detours_sys::DetourEnumerateModules(module);
- if module == ptr::null_mut() {
- break;
- }
- let mut size = 0;
- let payload = detours_sys::DetourFindPayload(module, &PAYLOAD_NVCUDA_GUID, &mut size);
- if payload != ptr::null_mut() {
- return true;
+impl PlatformLibrary {
+ #[allow(non_snake_case)]
+ unsafe fn new() -> Self {
+ let (LoadLibraryW, GetProcAddress) = match Self::get_detourer_module() {
+ None => (
+ LoadLibraryW as unsafe extern "system" fn(*const u16) -> HMODULE,
+ mem::transmute(
+ GetProcAddress
+ as unsafe extern "system" fn(
+ hModule: HMODULE,
+ lpProcName: *const i8,
+ ) -> FARPROC,
+ ),
+ ),
+ Some(zluda_with) => (
+ mem::transmute(GetProcAddress(
+ zluda_with,
+ LOAD_LIBRARY_NO_REDIRECT.as_ptr() as _,
+ )),
+ mem::transmute(GetProcAddress(
+ zluda_with,
+ GET_PROC_ADDRESS_NO_REDIRECT.as_ptr() as _,
+ )),
+ ),
+ };
+ PlatformLibrary {
+ LoadLibraryW,
+ GetProcAddress,
}
}
- false
-}
-unsafe fn get_non_detoured_load_library(
-) -> Option<unsafe extern "system" fn(*const c_ushort) -> HMODULE> {
- let mut module = ptr::null_mut();
- loop {
- module = detours_sys::DetourEnumerateModules(module);
- if module == ptr::null_mut() {
- break;
- }
- let result = GetProcAddress(
- module as *mut _,
- LOAD_LIBRARY_NO_REDIRECT.as_ptr() as *mut _,
- );
- if result != ptr::null_mut() {
- return Some(mem::transmute(result));
+ unsafe fn get_detourer_module() -> Option<HMODULE> {
+ let mut module = ptr::null_mut();
+ loop {
+ module = detours_sys::DetourEnumerateModules(module);
+ if module == ptr::null_mut() {
+ break;
+ }
+ let mut size = 0;
+ let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
+ if payload != ptr::null_mut() {
+ return Some(module as _);
+ }
}
+ None
}
- None
+}
+
+pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
+ let libcuda_path_uf16 = libcuda_path
+ .encode_utf16()
+ .chain(std::iter::once(0))
+ .collect::<Vec<_>>();
+ (PLATFORM_LIBRARY.LoadLibraryW)(libcuda_path_uf16.as_ptr()) as _
}
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
- GetProcAddress(handle as *mut _, func.as_ptr()) as *mut _
+ (PLATFORM_LIBRARY.GetProcAddress)(handle as _, func.as_ptr() as _) as _
}
#[macro_export]
|