diff options
-rw-r--r-- | cuda_base/src/lib.rs | 69 | ||||
-rw-r--r-- | zluda_dump/src/format.rs | 1 | ||||
-rw-r--r-- | zluda_dump/src/lib.rs | 135 | ||||
-rw-r--r-- | zluda_dump/src/os_unix.rs | 2 | ||||
-rw-r--r-- | zluda_dump/src/os_win.rs | 2 | ||||
-rw-r--r-- | zluda_dump/src/side_by_side.rs | 77 |
6 files changed, 205 insertions, 81 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index 8b804d1..3f6f779 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -9,12 +9,11 @@ use quote::{format_ident, quote, ToTokens}; use rustc_hash::{FxHashMap, FxHashSet}; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; -use syn::token::Brace; use syn::visit_mut::VisitMut; use syn::{ bracketed, parse_macro_input, Abi, Fields, File, FnArg, ForeignItem, ForeignItemFn, Ident, - Item, ItemForeignMod, ItemMacro, LitStr, Macro, MacroDelimiter, PatType, Path, PathArguments, - PathSegment, ReturnType, Signature, Token, Type, TypeArray, TypePath, TypePtr, + Item, ItemForeignMod, LitStr, PatType, Path, PathArguments, PathSegment, ReturnType, Signature, + Token, Type, TypeArray, TypePath, TypePtr, }; const CUDA_RS: &'static str = include_str! {"cuda.rs"}; @@ -109,8 +108,11 @@ impl VisitMut for FixAbi { // Then macro goes through every function in rust.rs, and for every fn `foo`: // * if `foo` is contained in `override_fns` then pass it into `override_macro` // * if `foo` is not contained in `override_fns` pass it to `normal_macro` -// Both `override_macro` and `normal_macro` expect this format: -// macro_foo!("system" fn cuCtxDetach(ctx: CUcontext) -> CUresult) +// Both `override_macro` and `normal_macro` expect semicolon-separated list: +// macro_foo!( +// "system" fn cuCtxDetach(ctx: CUcontext) -> CUresult; +// "system" fn cuCtxDetach(ctx: CUcontext) -> CUresult +// ) // Additionally, it does a fixup of CUDA types so they get prefixed with `type_path` #[proc_macro] pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream { @@ -121,7 +123,7 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream { .iter() .map(ToString::to_string) .collect::<FxHashSet<_>>(); - cuda_module + let (normal_macro_args, override_macro_args): (Vec<_>, Vec<_>) = cuda_module .items .into_iter() .filter_map(|item| match item { @@ -136,12 +138,7 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream { }, .. }) => { - let path = if override_fns.contains(&ident.to_string()) { - &input.override_macro - } else { - &input.normal_macro - } - .clone(); + let use_normal_macro = !override_fns.contains(&ident.to_string()); let inputs = inputs .into_iter() .map(|fn_arg| match fn_arg { @@ -158,30 +155,42 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream { ReturnType::Default => unreachable!(), }; let type_path = input.type_path.clone(); - let tokens = quote! { - "system" fn #ident(#inputs) -> #type_path :: #output - }; - Some(Item::Macro(ItemMacro { - attrs: Vec::new(), - ident: None, - mac: Macro { - path, - bang_token: Token![!](Span::call_site()), - delimiter: MacroDelimiter::Brace(Brace { - span: Span::call_site(), - }), - tokens, + Some(( + quote! { + "system" fn #ident(#inputs) -> #type_path :: #output }, - semi_token: None, - })) + use_normal_macro, + )) } _ => unreachable!(), }, _ => None, }) - .map(Item::into_token_stream) - .collect::<proc_macro2::TokenStream>() - .into() + .partition(|(_, use_normal_macro)| *use_normal_macro); + let mut result = proc_macro2::TokenStream::new(); + if !normal_macro_args.is_empty() { + let punctuated_normal_macro_args = to_punctuated::<Token![;]>(normal_macro_args); + let macro_ = &input.normal_macro; + result.extend(iter::once(quote! { + #macro_ ! (#punctuated_normal_macro_args); + })); + } + if !override_macro_args.is_empty() { + let punctuated_override_macro_args = to_punctuated::<Token![;]>(override_macro_args); + let macro_ = &input.override_macro; + result.extend(iter::once(quote! { + #macro_ ! (#punctuated_override_macro_args); + })); + } + result.into() +} + +fn to_punctuated<P: ToTokens + Default>( + elms: Vec<(proc_macro2::TokenStream, bool)>, +) -> proc_macro2::TokenStream { + let mut collection = Punctuated::<proc_macro2::TokenStream, P>::new(); + collection.extend(elms.into_iter().map(|(token_stream, _)| token_stream)); + collection.into_token_stream() } fn prepend_cuda_path_to_type(base_path: &Path, type_: Box<Type>) -> Box<Type> { diff --git a/zluda_dump/src/format.rs b/zluda_dump/src/format.rs index 8080fbc..380e52d 100644 --- a/zluda_dump/src/format.rs +++ b/zluda_dump/src/format.rs @@ -1,4 +1,3 @@ -extern crate cuda_types;
use std::{
ffi::{c_void, CStr},
fmt::LowerHex,
diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs index d79c391..04fc36e 100644 --- a/zluda_dump/src/lib.rs +++ b/zluda_dump/src/lib.rs @@ -2,6 +2,7 @@ use cuda_types::{ CUdevice, CUdevice_attribute, CUfunction, CUjit_option, CUmodule, CUresult, CUuuid, }; use paste::paste; +use side_by_side::CudaDynamicFns; use std::io; use std::{ collections::HashMap, env, error::Error, ffi::c_void, fs, path::PathBuf, ptr::NonNull, rc::Rc, @@ -10,47 +11,50 @@ use std::{ #[macro_use] extern crate lazy_static; +extern crate cuda_types; macro_rules! extern_redirect { - ($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path) => { - #[no_mangle] - pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { - let original_fn = |fn_ptr| { - let typed_fn = unsafe { std::mem::transmute::<_, extern "system" fn( $( $arg_id : $arg_type),* ) -> $ret_type>(fn_ptr) }; - typed_fn($( $arg_id ),*) - }; - let get_formatted_args = Box::new(move |writer: &mut dyn std::io::Write| { - (paste! { format :: [<write_ $fn_name>] }) ( - writer - $(,$arg_id)* - ) - }); - crate::handle_cuda_function_call(stringify!($fn_name), original_fn, get_formatted_args) - } + ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path);*) => { + $( + #[no_mangle] + pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { + let original_fn = |dynamic_fns: &mut crate::side_by_side::CudaDynamicFns| { + dynamic_fns.$fn_name($( $arg_id ),*) + }; + let get_formatted_args = Box::new(move |writer: &mut dyn std::io::Write| { + (paste! { format :: [<write_ $fn_name>] }) ( + writer + $(,$arg_id)* + ) + }); + crate::handle_cuda_function_call(stringify!($fn_name), original_fn, get_formatted_args) + } + )* }; } macro_rules! extern_redirect_with_post { - ($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path) => { - #[no_mangle] - pub extern "system" fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { - let original_fn = |fn_ptr| { - let typed_fn = unsafe { std::mem::transmute::<_, extern "system" fn( $( $arg_id : $arg_type),* ) -> $ret_type>(fn_ptr) }; - typed_fn($( $arg_id ),*) - }; - let get_formatted_args = Box::new(move |writer: &mut dyn std::io::Write| { - (paste! { format :: [<write_ $fn_name>] }) ( - writer - $(,$arg_id)* + ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path);*) => { + $( + #[no_mangle] + pub extern "system" fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { + let original_fn = |dynamic_fns: &mut crate::side_by_side::CudaDynamicFns| { + dynamic_fns.$fn_name($( $arg_id ),*) + }; + let get_formatted_args = Box::new(move |writer: &mut dyn std::io::Write| { + (paste! { format :: [<write_ $fn_name>] }) ( + writer + $(,$arg_id)* + ) + }); + crate::handle_cuda_function_call_with_probes( + stringify!($fn_name), + || (), original_fn, + get_formatted_args, + move |logger, state, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , logger, state, cuda_result ) ) - }); - crate::handle_cuda_function_call_with_probes( - stringify!($fn_name), - || (), original_fn, - get_formatted_args, - move |logger, state, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , logger, state, cuda_result ) - ) - } + } + )* }; } @@ -77,6 +81,7 @@ mod log; #[cfg_attr(windows, path = "os_win.rs")] #[cfg_attr(not(windows), path = "os_unix.rs")] mod os; +mod side_by_side; mod trace; lazy_static! { @@ -127,7 +132,8 @@ impl<T> LateInit<T> { struct GlobalDelayedState { settings: Settings, - libcuda_handle: NonNull<c_void>, + libcuda: CudaDynamicFns, + side_by_side_lib: Option<CudaDynamicFns>, cuda_state: trace::StateTracker, } @@ -139,9 +145,8 @@ impl GlobalDelayedState { ) -> (LateInit<Self>, log::FunctionLogger<'a>) { let (mut fn_logger, settings) = factory.get_first_logger_and_init_settings(func, arguments_writer); - let maybe_libcuda_handle = unsafe { os::load_cuda_library(&settings.libcuda_path) }; - let libcuda_handle = match NonNull::new(maybe_libcuda_handle) { - Some(h) => h, + let libcuda = match unsafe { CudaDynamicFns::load_library(&settings.libcuda_path) } { + Some(libcuda) => libcuda, None => { fn_logger.log(log::LogEntry::ErrorBox( format!("Invalid CUDA library at path {}", &settings.libcuda_path).into(), @@ -149,11 +154,30 @@ impl GlobalDelayedState { return (LateInit::Error, fn_logger); } }; + let side_by_side_lib = settings + .side_by_side_path + .as_ref() + .and_then(|side_by_side_path| { + match unsafe { CudaDynamicFns::load_library(&*side_by_side_path) } { + Some(fns) => Some(fns), + None => { + fn_logger.log(log::LogEntry::ErrorBox( + format!( + "Invalid side-by-side CUDA library at path {}", + &side_by_side_path + ) + .into(), + )); + None + } + } + }); let cuda_state = trace::StateTracker::new(&settings); let delayed_state = GlobalDelayedState { settings, - libcuda_handle, + libcuda, cuda_state, + side_by_side_lib, }; (LateInit::Success(delayed_state), fn_logger) } @@ -163,6 +187,7 @@ struct Settings { dump_dir: Option<PathBuf>, libcuda_path: String, override_cc_major: Option<u32>, + side_by_side_path: Option<String>, } impl Settings { @@ -179,7 +204,7 @@ impl Settings { None } }; - let libcuda_path = match env::var("ZLUDA_DUMP_LIBCUDA_FILE") { + let libcuda_path = match env::var("ZLUDA_CUDA_LIB") { Err(env::VarError::NotPresent) => os::LIBCUDA_DEFAULT_PATH.to_owned(), Err(e) => { logger.log(log::LogEntry::ErrorBox(Box::new(e) as _)); @@ -201,10 +226,19 @@ impl Settings { Ok(cc) => Some(cc), }, }; + let side_by_side_path = match env::var("ZLUDA_SIDE_BY_SIDE_LIB") { + Err(env::VarError::NotPresent) => None, + Err(e) => { + logger.log(log::LogEntry::ErrorBox(Box::new(e) as _)); + None + } + Ok(env_string) => Some(env_string), + }; Settings { dump_dir, libcuda_path, override_cc_major, + side_by_side_path, } } @@ -241,7 +275,7 @@ pub struct ModuleDump { fn handle_cuda_function_call( func: &'static str, - original_cuda_fn: impl FnOnce(NonNull<c_void>) -> CUresult, + original_cuda_fn: impl FnOnce(&mut CudaDynamicFns) -> Option<CUresult>, arguments_writer: Box<dyn FnMut(&mut dyn std::io::Write) -> std::io::Result<()>>, ) -> CUresult { handle_cuda_function_call_with_probes( @@ -256,7 +290,7 @@ fn handle_cuda_function_call( fn handle_cuda_function_call_with_probes<T, PostFn>( func: &'static str, pre_probe: impl FnOnce() -> T, - original_cuda_fn: impl FnOnce(NonNull<c_void>) -> CUresult, + original_cuda_fn: impl FnOnce(&mut CudaDynamicFns) -> Option<CUresult>, arguments_writer: Box<dyn FnMut(&mut dyn std::io::Write) -> std::io::Result<()>>, post_probe: PostFn, ) -> CUresult @@ -283,13 +317,18 @@ where (logger, global_state.delayed_state.as_mut().unwrap()) } }; - let name = std::ffi::CString::new(func).unwrap(); - let fn_ptr = - unsafe { os::get_proc_address(delayed_state.libcuda_handle.as_ptr(), name.as_c_str()) }; - let fn_ptr = NonNull::new(fn_ptr).unwrap(); let pre_result = pre_probe(); - let cu_result = original_cuda_fn(fn_ptr); - logger.result = Some(cu_result); + let maybe_cu_result = original_cuda_fn(&mut delayed_state.libcuda); + let cu_result = match maybe_cu_result { + Some(result) => result, + None => { + logger.log(log::LogEntry::ErrorBox( + format!("No function {} in the underlying CUDA library", func).into(), + )); + CUresult::CUDA_ERROR_UNKNOWN + } + }; + logger.result = maybe_cu_result; post_probe( &mut logger, &mut delayed_state.cuda_state, diff --git a/zluda_dump/src/os_unix.rs b/zluda_dump/src/os_unix.rs index 3b37e74..e1e516b 100644 --- a/zluda_dump/src/os_unix.rs +++ b/zluda_dump/src/os_unix.rs @@ -4,7 +4,7 @@ use std::mem; pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = b"/usr/lib/x86_64-linux-gnu/libcuda.so.1\0";
-pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
+pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
let libcuda_path = CString::new(libcuda_path).unwrap();
libc::dlopen(
libcuda_path.as_ptr() as *const _,
diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs index c138cc0..ef3da44 100644 --- a/zluda_dump/src/os_win.rs +++ b/zluda_dump/src/os_win.rs @@ -73,7 +73,7 @@ impl PlatformLibrary { }
}
-pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
+pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
let libcuda_path_uf16 = libcuda_path
.encode_utf16()
.chain(std::iter::once(0))
diff --git a/zluda_dump/src/side_by_side.rs b/zluda_dump/src/side_by_side.rs new file mode 100644 index 0000000..33954b8 --- /dev/null +++ b/zluda_dump/src/side_by_side.rs @@ -0,0 +1,77 @@ +use cuda_base::cuda_function_declarations;
+use std::ffi::CStr;
+use std::mem;
+use std::ptr;
+use std::ptr::NonNull;
+use std::{marker::PhantomData, os::raw::c_void};
+
+use crate::os;
+
+struct DynamicFn<T> {
+ pointer: usize,
+ _marker: PhantomData<T>,
+}
+
+impl<T> Default for DynamicFn<T> {
+ fn default() -> Self {
+ DynamicFn {
+ pointer: 0,
+ _marker: PhantomData,
+ }
+ }
+}
+
+impl<T> DynamicFn<T> {
+ unsafe fn get(&mut self, lib: *mut c_void, name: &[u8]) -> Option<T> {
+ match self.pointer {
+ 0 => {
+ let addr = os::get_proc_address(lib, CStr::from_bytes_with_nul_unchecked(name));
+ if addr == ptr::null_mut() {
+ self.pointer = 1;
+ return None;
+ } else {
+ self.pointer = addr as _;
+ }
+ }
+ 1 => return None,
+ _ => {}
+ }
+ Some(mem::transmute_copy(&self.pointer))
+ }
+}
+
+pub(crate) struct CudaDynamicFns {
+ lib_handle: NonNull<::std::ffi::c_void>,
+ fn_table: CudaFnTable,
+}
+
+impl CudaDynamicFns {
+ pub(crate) unsafe fn load_library(path: &str) -> Option<Self> {
+ let lib_handle = NonNull::new(os::load_library(path));
+ lib_handle.map(|lib_handle| CudaDynamicFns {
+ lib_handle,
+ fn_table: CudaFnTable::default(),
+ })
+ }
+}
+
+macro_rules! emit_cuda_fn_table {
+ ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path);*) => {
+ #[derive(Default)]
+ struct CudaFnTable {
+ $($fn_name: DynamicFn<extern $abi fn ( $($arg_id : $arg_type),* ) -> $ret_type>),*
+ }
+
+ impl CudaDynamicFns {
+ $(
+ #[allow(dead_code)]
+ pub(crate) fn $fn_name(&mut self, $($arg_id : $arg_type),*) -> Option<$ret_type> {
+ let func = unsafe { self.fn_table.$fn_name.get(self.lib_handle.as_ptr(), concat!(stringify!($fn_name), "\0").as_bytes()) };
+ func.map(|f| f($($arg_id),*) )
+ }
+ )*
+ }
+ };
+}
+
+cuda_function_declarations!(cuda_types, emit_cuda_fn_table, emit_cuda_fn_table, []);
|