aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_redirect
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-02-28 01:50:04 +0100
committerAndrzej Janik <[email protected]>2021-02-28 01:50:04 +0100
commitba83bb28f72faee394a3ecd583eb1cff7a41516d (patch)
tree1e67dd333cfb1bb125b6d63958e4133fec3d4871 /zluda_redirect
parentb7ee6d66c3cfac9922addc6c362ae437cfa8fee5 (diff)
downloadZLUDA-ba83bb28f72faee394a3ecd583eb1cff7a41516d.tar.gz
ZLUDA-ba83bb28f72faee394a3ecd583eb1cff7a41516d.zip
Inject our own NVML
Diffstat (limited to 'zluda_redirect')
-rw-r--r--zluda_redirect/src/lib.rs88
-rw-r--r--zluda_redirect/src/payload_guid.rs10
2 files changed, 78 insertions, 20 deletions
diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs
index 5de7530..bfd8200 100644
--- a/zluda_redirect/src/lib.rs
+++ b/zluda_redirect/src/lib.rs
@@ -55,8 +55,12 @@ include!("payload_guid.rs");
const NVCUDA_UTF8: &'static str = "NVCUDA.DLL";
const NVCUDA_UTF16: &[u16] = wch!("NVCUDA.DLL");
+const NVML_UTF8: &'static str = "NVML.DLL";
+const NVML_UTF16: &[u16] = wch!("NVML.DLL");
static mut ZLUDA_PATH_UTF8: Vec<u8> = Vec::new();
static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None;
+static mut ZLUDA_ML_PATH_UTF8: Vec<u8> = Vec::new();
+static mut ZLUDA_ML_PATH_UTF16: Option<&'static [u16]> = None;
static mut DETACH_LOAD_LIBRARY: bool = false;
static mut NVCUDA_ORIGINAL_MODULE: HMODULE = ptr::null_mut();
static mut CUINIT_ORIGINAL_FN: FARPROC = ptr::null_mut();
@@ -158,6 +162,8 @@ unsafe extern "system" fn ZludaLoadLibraryW_NoRedirect(lpLibFileName: LPCWSTR) -
unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE {
let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
ZLUDA_PATH_UTF8.as_ptr() as *const _
+ } else if is_nvml_dll_utf8(lpLibFileName as *const _) {
+ ZLUDA_ML_PATH_UTF8.as_ptr() as *const _
} else {
lpLibFileName
};
@@ -168,6 +174,8 @@ unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE {
unsafe extern "system" fn ZludaLoadLibraryW(lpLibFileName: LPCWSTR) -> HMODULE {
let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
ZLUDA_PATH_UTF16.unwrap().as_ptr()
+ } else if is_nvml_dll_utf16(lpLibFileName as *const _) {
+ ZLUDA_ML_PATH_UTF16.unwrap().as_ptr()
} else {
lpLibFileName
};
@@ -182,6 +190,8 @@ unsafe extern "system" fn ZludaLoadLibraryExA(
) -> HMODULE {
let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
ZLUDA_PATH_UTF8.as_ptr() as *const _
+ } else if is_nvml_dll_utf8(lpLibFileName as *const _) {
+ ZLUDA_ML_PATH_UTF8.as_ptr() as *const _
} else {
lpLibFileName
};
@@ -196,6 +206,8 @@ unsafe extern "system" fn ZludaLoadLibraryExW(
) -> HMODULE {
let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) {
ZLUDA_PATH_UTF16.unwrap().as_ptr()
+ } else if is_nvml_dll_utf16(lpLibFileName as *const _) {
+ ZLUDA_ML_PATH_UTF16.unwrap().as_ptr()
} else {
lpLibFileName
};
@@ -363,7 +375,7 @@ unsafe fn continue_create_process_hook(
}
if detours_sys::DetourCopyPayloadToProcess(
(*process_information).hProcess,
- &PAYLOAD_GUID,
+ &PAYLOAD_NVCUDA_GUID,
ZLUDA_PATH_UTF16.unwrap().as_ptr() as *mut _,
(ZLUDA_PATH_UTF16.unwrap().len() * mem::size_of::<u16>()) as u32,
) == FALSE
@@ -372,6 +384,16 @@ unsafe fn continue_create_process_hook(
return 0;
}
+ if detours_sys::DetourCopyPayloadToProcess(
+ (*process_information).hProcess,
+ &PAYLOAD_NVML_GUID,
+ ZLUDA_ML_PATH_UTF16.unwrap().as_ptr() as *mut _,
+ (ZLUDA_ML_PATH_UTF16.unwrap().len() * mem::size_of::<u16>()) as u32,
+ ) == FALSE
+ {
+ TerminateProcess((*process_information).hProcess, 1);
+ return 0;
+ }
if creation_flags & CREATE_SUSPENDED == 0 {
if ResumeThread((*process_information).hThread) == -1i32 as u32 {
TerminateProcess((*process_information).hProcess, 1);
@@ -490,7 +512,23 @@ unsafe extern "C" fn unsupported_cuda_fn() -> c_uint {
}
fn is_nvcuda_dll_utf8(lib: *const u8) -> bool {
- is_nvcuda_dll(lib, 0, NVCUDA_UTF8.as_bytes(), |c| {
+ is_dll_utf8(lib, NVCUDA_UTF8.as_bytes())
+}
+
+fn is_nvcuda_dll_utf16(lib: *const u16) -> bool {
+ is_dll_utf16(lib, NVCUDA_UTF16)
+}
+
+fn is_nvml_dll_utf8(lib: *const u8) -> bool {
+ is_dll_utf8(lib, NVML_UTF8.as_bytes())
+}
+
+fn is_nvml_dll_utf16(lib: *const u16) -> bool {
+ is_dll_utf16(lib, NVML_UTF16)
+}
+
+fn is_dll_utf8(lib: *const u8, name: &[u8]) -> bool {
+ is_dll_impl(lib, 0, name, |c| {
if c >= 'a' as u8 && c <= 'z' as u8 {
c - 32
} else {
@@ -498,8 +536,9 @@ fn is_nvcuda_dll_utf8(lib: *const u8) -> bool {
}
})
}
-fn is_nvcuda_dll_utf16(lib: *const u16) -> bool {
- is_nvcuda_dll(lib, 0u16, NVCUDA_UTF16, |c| {
+
+fn is_dll_utf16(lib: *const u16, name: &[u16]) -> bool {
+ is_dll_impl(lib, 0u16, name, |c| {
if c >= 'a' as u16 && c <= 'z' as u16 {
c - 32
} else {
@@ -508,7 +547,7 @@ fn is_nvcuda_dll_utf16(lib: *const u16) -> bool {
})
}
-fn is_nvcuda_dll<T: Copy + PartialEq>(
+fn is_dll_impl<T: Copy + PartialEq>(
lib: *const T,
zero: T,
dll_name: &[T],
@@ -544,11 +583,13 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
if !initialize_current_module_name(instDLL) {
return FALSE;
}
- match get_zluda_dll_path() {
- Some(path) => {
- ZLUDA_PATH_UTF16 = Some(path);
+ match get_zluda_dlls_paths() {
+ Some((nvcuda_path, nvml_path)) => {
+ ZLUDA_PATH_UTF16 = Some(nvcuda_path);
+ ZLUDA_ML_PATH_UTF16 = Some(nvml_path);
// from_utf16_lossy(...) handles terminating NULL correctly
- ZLUDA_PATH_UTF8 = String::from_utf16_lossy(path).into_bytes();
+ ZLUDA_PATH_UTF8 = String::from_utf16_lossy(nvcuda_path).into_bytes();
+ ZLUDA_ML_PATH_UTF8 = String::from_utf16_lossy(nvml_path).into_bytes();
}
None => return FALSE,
}
@@ -740,25 +781,34 @@ unsafe fn detach_load_library() -> i32 {
TRUE
}
-fn get_zluda_dll_path() -> Option<&'static [u16]> {
+fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> {
+ match get_payload(&PAYLOAD_NVCUDA_GUID) {
+ None => None,
+ Some(nvcuda_payload) => match get_payload(&PAYLOAD_NVML_GUID) {
+ None => return None,
+ Some(nvml_payload) => return Some((nvcuda_payload, nvml_payload)),
+ },
+ }
+}
+
+fn get_payload(guid: &detours_sys::GUID) -> Option<&'static [u16]> {
let mut module = ptr::null_mut();
loop {
module = unsafe { detours_sys::DetourEnumerateModules(module) };
if module == ptr::null_mut() {
- break;
+ return None;
}
let mut size = 0;
- let payload = unsafe { detours_sys::DetourFindPayload(module, &PAYLOAD_GUID, &mut size) };
- if payload != ptr::null_mut() {
- return unsafe {
- Some(slice::from_raw_parts(
- payload as *const _,
+ let payload_ptr = unsafe { detours_sys::DetourFindPayload(module, guid, &mut size) };
+ if payload_ptr != ptr::null_mut() {
+ return Some(unsafe {
+ slice::from_raw_parts(
+ payload_ptr as *const _,
(size as usize) / mem::size_of::<u16>(),
- ))
- };
+ )
+ });
}
}
- None
}
#[must_use]
diff --git a/zluda_redirect/src/payload_guid.rs b/zluda_redirect/src/payload_guid.rs
index eaf021d..2d7ee6c 100644
--- a/zluda_redirect/src/payload_guid.rs
+++ b/zluda_redirect/src/payload_guid.rs
@@ -1,6 +1,14 @@
-const PAYLOAD_GUID: detours_sys::GUID = detours_sys::GUID {
+const PAYLOAD_NVCUDA_GUID: detours_sys::GUID = detours_sys::GUID {
Data1: 0xC225FC0C,
Data2: 0x00D7,
Data3: 0x40B8,
Data4: [0x93, 0x5A, 0x7E, 0x34, 0x2A, 0x93, 0x44, 0xC1],
+};
+
+#[allow(dead_code)]
+const PAYLOAD_NVML_GUID: detours_sys::GUID = detours_sys::GUID {
+ Data1: 0x75B54759,
+ Data2: 0xB6F1,
+ Data3: 0x49C2,
+ Data4: [0xA2, 0x09, 0x68, 0x54, 0x96, 0xBD, 0x70, 0xC0],
}; \ No newline at end of file