diff options
-rw-r--r-- | zluda_dump/src/os_win.rs | 2 | ||||
-rw-r--r-- | zluda_inject/src/bin.rs | 57 | ||||
-rw-r--r-- | zluda_ml/README | 2 | ||||
-rw-r--r-- | zluda_ml/src/impl.rs | 2 | ||||
-rw-r--r-- | zluda_redirect/src/lib.rs | 88 | ||||
-rw-r--r-- | zluda_redirect/src/payload_guid.rs | 10 |
6 files changed, 121 insertions, 40 deletions
diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs index 7f985c5..7e411ac 100644 --- a/zluda_dump/src/os_win.rs +++ b/zluda_dump/src/os_win.rs @@ -36,7 +36,7 @@ unsafe fn is_detoured() -> bool { break;
}
let mut size = 0;
- let payload = detours_sys::DetourFindPayload(module, &PAYLOAD_GUID, &mut size);
+ let payload = detours_sys::DetourFindPayload(module, &PAYLOAD_NVCUDA_GUID, &mut size);
if payload != ptr::null_mut() {
return true;
}
diff --git a/zluda_inject/src/bin.rs b/zluda_inject/src/bin.rs index ce83fe9..42ae748 100644 --- a/zluda_inject/src/bin.rs +++ b/zluda_inject/src/bin.rs @@ -1,8 +1,8 @@ -use std::mem;
use std::path::Path;
use std::ptr;
use std::{env, ops::Deref};
use std::{error::Error, process};
+use std::{mem, path::PathBuf};
use mem::size_of_val;
use winapi::um::{
@@ -20,6 +20,7 @@ use winapi::um::winbase::{INFINITE, WAIT_FAILED}; static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
static ZLUDA_DLL: &'static str = "nvcuda.dll";
+static ZLUDA_ML_DLL: &'static str = "nvml.dll";
include!("../../zluda_redirect/src/payload_guid.rs");
@@ -31,7 +32,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> { let injector_path = env::current_exe()?;
let injector_dir = injector_path.parent().unwrap();
let redirect_path = create_redirect_path(injector_dir);
- let (mut inject_path, cmd) = create_inject_path(&args[1..], injector_dir);
+ let (mut inject_nvcuda_path, mut inject_nvml_path, cmd) =
+ create_inject_path(&args[1..], injector_dir);
let mut cmd_line = construct_command_line(cmd);
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
@@ -56,9 +58,18 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> { os_call!(
detours_sys::DetourCopyPayloadToProcess(
proc_info.hProcess,
- &PAYLOAD_GUID,
- inject_path.as_mut_ptr() as *mut _,
- (inject_path.len() * mem::size_of::<u16>()) as u32
+ &PAYLOAD_NVCUDA_GUID,
+ inject_nvcuda_path.as_mut_ptr() as *mut _,
+ (inject_nvcuda_path.len() * mem::size_of::<u16>()) as u32
+ ),
+ |x| x != 0
+ );
+ os_call!(
+ detours_sys::DetourCopyPayloadToProcess(
+ proc_info.hProcess,
+ &PAYLOAD_NVML_GUID,
+ inject_nvml_path.as_mut_ptr() as *mut _,
+ (inject_nvml_path.len() * mem::size_of::<u16>()) as u32
),
|x| x != 0
);
@@ -173,22 +184,34 @@ fn create_redirect_path(injector_dir: &Path) -> Vec<u8> { result
}
-fn create_inject_path<'a>(args: &'a [String], injector_dir: &Path) -> (Vec<u16>, &'a [String]) {
- if args.get(0).map(Deref::deref) == Some("--") {
- let mut injector_dir = injector_dir.to_path_buf();
- injector_dir.push(ZLUDA_DLL);
- let mut result = injector_dir
- .to_string_lossy()
- .as_ref()
- .encode_utf16()
- .collect::<Vec<_>>();
- result.push(0);
- (result, &args[1..])
+fn create_inject_path<'a>(
+ args: &'a [String],
+ injector_dir: &Path,
+) -> (Vec<u16>, Vec<u16>, &'a [String]) {
+ let injector_dir = injector_dir.to_path_buf();
+ let (nvcuda_path, unparsed_args) = if args.get(0).map(Deref::deref) == Some("--") {
+ (
+ encode_file_in_directory_raw(injector_dir.clone(), ZLUDA_DLL),
+ &args[1..],
+ )
} else if args.get(1).map(Deref::deref) == Some("--") {
let mut dll_path = args[0].encode_utf16().collect::<Vec<_>>();
dll_path.push(0);
(dll_path, &args[2..])
} else {
print_help_and_exit()
- }
+ };
+ let nvml_path = encode_file_in_directory_raw(injector_dir, ZLUDA_ML_DLL);
+ (nvcuda_path, nvml_path, unparsed_args)
+}
+
+fn encode_file_in_directory_raw(mut dir: PathBuf, file: &'static str) -> Vec<u16> {
+ dir.push(file);
+ let mut result = dir
+ .to_string_lossy()
+ .as_ref()
+ .encode_utf16()
+ .collect::<Vec<_>>();
+ result.push(0);
+ result
}
diff --git a/zluda_ml/README b/zluda_ml/README index 300a2a3..60a59ad 100644 --- a/zluda_ml/README +++ b/zluda_ml/README @@ -1,3 +1,3 @@ bindgen "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\include\nvml.h" --whitelist-function="^nvml.*" --size_t-is-usize --default-enum-style=newtype --no-layout-tests --no-doc-comments --no-derive-debug -o src/nvml.rs
-sed -i -e 's/extern "C" {//g' -e 's/-> nvmlReturn_t;/-> nvmlReturn_t { impl_::unsupported()/g' -e 's/pub fn /#[no_mangle] pub extern "C" fn /g' src/nvml.rs
+sed -i -e 's/extern "C" {//g' -e 's/-> nvmlReturn_t;/-> nvmlReturn_t { crate::r#impl::unimplemented()/g' -e 's/pub fn /#[no_mangle] pub extern "C" fn /g' src/nvml.rs
rustfmt src/nvml.rs
\ No newline at end of file diff --git a/zluda_ml/src/impl.rs b/zluda_ml/src/impl.rs index 0fd0de1..75f3ca2 100644 --- a/zluda_ml/src/impl.rs +++ b/zluda_ml/src/impl.rs @@ -1,5 +1,5 @@ use level_zero as l0;
-use std::{io::Write, ops::Add};
+use std::io::Write;
use std::{
os::raw::{c_char, c_uint},
ptr,
diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs index 5de7530..bfd8200 100644 --- a/zluda_redirect/src/lib.rs +++ b/zluda_redirect/src/lib.rs @@ -55,8 +55,12 @@ include!("payload_guid.rs"); const NVCUDA_UTF8: &'static str = "NVCUDA.DLL"; const NVCUDA_UTF16: &[u16] = wch!("NVCUDA.DLL"); +const NVML_UTF8: &'static str = "NVML.DLL"; +const NVML_UTF16: &[u16] = wch!("NVML.DLL"); static mut ZLUDA_PATH_UTF8: Vec<u8> = Vec::new(); static mut ZLUDA_PATH_UTF16: Option<&'static [u16]> = None; +static mut ZLUDA_ML_PATH_UTF8: Vec<u8> = Vec::new(); +static mut ZLUDA_ML_PATH_UTF16: Option<&'static [u16]> = None; static mut DETACH_LOAD_LIBRARY: bool = false; static mut NVCUDA_ORIGINAL_MODULE: HMODULE = ptr::null_mut(); static mut CUINIT_ORIGINAL_FN: FARPROC = ptr::null_mut(); @@ -158,6 +162,8 @@ unsafe extern "system" fn ZludaLoadLibraryW_NoRedirect(lpLibFileName: LPCWSTR) - unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE { let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) { ZLUDA_PATH_UTF8.as_ptr() as *const _ + } else if is_nvml_dll_utf8(lpLibFileName as *const _) { + ZLUDA_ML_PATH_UTF8.as_ptr() as *const _ } else { lpLibFileName }; @@ -168,6 +174,8 @@ unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE { unsafe extern "system" fn ZludaLoadLibraryW(lpLibFileName: LPCWSTR) -> HMODULE { let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) { ZLUDA_PATH_UTF16.unwrap().as_ptr() + } else if is_nvml_dll_utf16(lpLibFileName as *const _) { + ZLUDA_ML_PATH_UTF16.unwrap().as_ptr() } else { lpLibFileName }; @@ -182,6 +190,8 @@ unsafe extern "system" fn ZludaLoadLibraryExA( ) -> HMODULE { let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) { ZLUDA_PATH_UTF8.as_ptr() as *const _ + } else if is_nvml_dll_utf8(lpLibFileName as *const _) { + ZLUDA_ML_PATH_UTF8.as_ptr() as *const _ } else { lpLibFileName }; @@ -196,6 +206,8 @@ unsafe extern "system" fn ZludaLoadLibraryExW( ) -> HMODULE { let nvcuda_file_name = if is_nvcuda_dll_utf16(lpLibFileName) { ZLUDA_PATH_UTF16.unwrap().as_ptr() + } else if is_nvml_dll_utf16(lpLibFileName as *const _) { + ZLUDA_ML_PATH_UTF16.unwrap().as_ptr() } else { lpLibFileName }; @@ -363,7 +375,7 @@ unsafe fn continue_create_process_hook( } if detours_sys::DetourCopyPayloadToProcess( (*process_information).hProcess, - &PAYLOAD_GUID, + &PAYLOAD_NVCUDA_GUID, ZLUDA_PATH_UTF16.unwrap().as_ptr() as *mut _, (ZLUDA_PATH_UTF16.unwrap().len() * mem::size_of::<u16>()) as u32, ) == FALSE @@ -372,6 +384,16 @@ unsafe fn continue_create_process_hook( return 0; } + if detours_sys::DetourCopyPayloadToProcess( + (*process_information).hProcess, + &PAYLOAD_NVML_GUID, + ZLUDA_ML_PATH_UTF16.unwrap().as_ptr() as *mut _, + (ZLUDA_ML_PATH_UTF16.unwrap().len() * mem::size_of::<u16>()) as u32, + ) == FALSE + { + TerminateProcess((*process_information).hProcess, 1); + return 0; + } if creation_flags & CREATE_SUSPENDED == 0 { if ResumeThread((*process_information).hThread) == -1i32 as u32 { TerminateProcess((*process_information).hProcess, 1); @@ -490,7 +512,23 @@ unsafe extern "C" fn unsupported_cuda_fn() -> c_uint { } fn is_nvcuda_dll_utf8(lib: *const u8) -> bool { - is_nvcuda_dll(lib, 0, NVCUDA_UTF8.as_bytes(), |c| { + is_dll_utf8(lib, NVCUDA_UTF8.as_bytes()) +} + +fn is_nvcuda_dll_utf16(lib: *const u16) -> bool { + is_dll_utf16(lib, NVCUDA_UTF16) +} + +fn is_nvml_dll_utf8(lib: *const u8) -> bool { + is_dll_utf8(lib, NVML_UTF8.as_bytes()) +} + +fn is_nvml_dll_utf16(lib: *const u16) -> bool { + is_dll_utf16(lib, NVML_UTF16) +} + +fn is_dll_utf8(lib: *const u8, name: &[u8]) -> bool { + is_dll_impl(lib, 0, name, |c| { if c >= 'a' as u8 && c <= 'z' as u8 { c - 32 } else { @@ -498,8 +536,9 @@ fn is_nvcuda_dll_utf8(lib: *const u8) -> bool { } }) } -fn is_nvcuda_dll_utf16(lib: *const u16) -> bool { - is_nvcuda_dll(lib, 0u16, NVCUDA_UTF16, |c| { + +fn is_dll_utf16(lib: *const u16, name: &[u16]) -> bool { + is_dll_impl(lib, 0u16, name, |c| { if c >= 'a' as u16 && c <= 'z' as u16 { c - 32 } else { @@ -508,7 +547,7 @@ fn is_nvcuda_dll_utf16(lib: *const u16) -> bool { }) } -fn is_nvcuda_dll<T: Copy + PartialEq>( +fn is_dll_impl<T: Copy + PartialEq>( lib: *const T, zero: T, dll_name: &[T], @@ -544,11 +583,13 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u if !initialize_current_module_name(instDLL) { return FALSE; } - match get_zluda_dll_path() { - Some(path) => { - ZLUDA_PATH_UTF16 = Some(path); + match get_zluda_dlls_paths() { + Some((nvcuda_path, nvml_path)) => { + ZLUDA_PATH_UTF16 = Some(nvcuda_path); + ZLUDA_ML_PATH_UTF16 = Some(nvml_path); // from_utf16_lossy(...) handles terminating NULL correctly - ZLUDA_PATH_UTF8 = String::from_utf16_lossy(path).into_bytes(); + ZLUDA_PATH_UTF8 = String::from_utf16_lossy(nvcuda_path).into_bytes(); + ZLUDA_ML_PATH_UTF8 = String::from_utf16_lossy(nvml_path).into_bytes(); } None => return FALSE, } @@ -740,25 +781,34 @@ unsafe fn detach_load_library() -> i32 { TRUE } -fn get_zluda_dll_path() -> Option<&'static [u16]> { +fn get_zluda_dlls_paths() -> Option<(&'static [u16], &'static [u16])> { + match get_payload(&PAYLOAD_NVCUDA_GUID) { + None => None, + Some(nvcuda_payload) => match get_payload(&PAYLOAD_NVML_GUID) { + None => return None, + Some(nvml_payload) => return Some((nvcuda_payload, nvml_payload)), + }, + } +} + +fn get_payload(guid: &detours_sys::GUID) -> Option<&'static [u16]> { let mut module = ptr::null_mut(); loop { module = unsafe { detours_sys::DetourEnumerateModules(module) }; if module == ptr::null_mut() { - break; + return None; } let mut size = 0; - let payload = unsafe { detours_sys::DetourFindPayload(module, &PAYLOAD_GUID, &mut size) }; - if payload != ptr::null_mut() { - return unsafe { - Some(slice::from_raw_parts( - payload as *const _, + let payload_ptr = unsafe { detours_sys::DetourFindPayload(module, guid, &mut size) }; + if payload_ptr != ptr::null_mut() { + return Some(unsafe { + slice::from_raw_parts( + payload_ptr as *const _, (size as usize) / mem::size_of::<u16>(), - )) - }; + ) + }); } } - None } #[must_use] diff --git a/zluda_redirect/src/payload_guid.rs b/zluda_redirect/src/payload_guid.rs index eaf021d..2d7ee6c 100644 --- a/zluda_redirect/src/payload_guid.rs +++ b/zluda_redirect/src/payload_guid.rs @@ -1,6 +1,14 @@ -const PAYLOAD_GUID: detours_sys::GUID = detours_sys::GUID {
+const PAYLOAD_NVCUDA_GUID: detours_sys::GUID = detours_sys::GUID {
Data1: 0xC225FC0C,
Data2: 0x00D7,
Data3: 0x40B8,
Data4: [0x93, 0x5A, 0x7E, 0x34, 0x2A, 0x93, 0x44, 0xC1],
+};
+
+#[allow(dead_code)]
+const PAYLOAD_NVML_GUID: detours_sys::GUID = detours_sys::GUID {
+ Data1: 0x75B54759,
+ Data2: 0xB6F1,
+ Data3: 0x49C2,
+ Data4: [0xA2, 0x09, 0x68, 0x54, 0x96, 0xBD, 0x70, 0xC0],
};
\ No newline at end of file |