diff options
-rw-r--r-- | ptx/src/test/spirv_run/call.ptx | 4 | ||||
-rw-r--r-- | ptx/src/translate.rs | 6 | ||||
-rw-r--r-- | zluda/src/impl/module.rs | 23 |
3 files changed, 18 insertions, 15 deletions
diff --git a/ptx/src/test/spirv_run/call.ptx b/ptx/src/test/spirv_run/call.ptx index f2ac39c..537fce2 100644 --- a/ptx/src/test/spirv_run/call.ptx +++ b/ptx/src/test/spirv_run/call.ptx @@ -2,7 +2,7 @@ .target sm_30 .address_size 64 -.func (.param.u64 output) incr (.param.u64 input); +.visible .func (.param.u64 output) incr (.param.u64 input); .visible .entry call( .param .u64 input, @@ -26,7 +26,7 @@ ret; } -.func (.param .u64 output) incr( +.visible .func (.param .u64 output) incr( .param .u64 input ) { diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 13c578b..2af7534 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -543,10 +543,10 @@ fn emit_directives<'input>( let f_body = match &f.body {
Some(f) => f,
None => {
- if f.linkage == ast::LinkingDirective::NONE {
- continue;
- } else {
+ if f.linkage.contains(ast::LinkingDirective::EXTERN) {
&empty_body
+ } else {
+ continue;
}
}
};
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 6234909..9732ec9 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -16,6 +16,7 @@ use hip_runtime_sys::{ use tempfile::NamedTempFile; use crate::cuda::CUmodule; +use crate::hip_call; pub struct SpirvModule { pub binaries: Vec<u32>, @@ -73,28 +74,31 @@ pub(crate) fn load_data( module: *mut CUmodule, image: *const std::ffi::c_void, ) -> Result<(), hipError_t> { + if image == ptr::null() { + return Err(hipError_t::hipErrorInvalidValue); + } + if unsafe { *(image as *const u32) } == 0x464c457f { + return match unsafe { hipModuleLoadData(module as _, image) } { + hipError_t::hipSuccess => Ok(()), + e => Err(e), + }; + } let spirv_data = SpirvModule::new_raw(image as *const _)?; load_data_impl(module, spirv_data) } pub fn load_data_impl(pmod: *mut CUmodule, spirv_data: SpirvModule) -> Result<(), hipError_t> { let mut dev = 0; - let err = unsafe { hipCtxGetDevice(&mut dev) }; - if err != hipError_t::hipSuccess { - return Err(err); - } + hip_call! { hipCtxGetDevice(&mut dev) }; let mut props = unsafe { mem::zeroed() }; - let err = unsafe { hipGetDeviceProperties(&mut props, dev) }; + hip_call! { hipGetDeviceProperties(&mut props, dev) }; let arch_binary = compile_amd( &props, iter::once(&spirv_data.binaries[..]), spirv_data.should_link_ptx_impl, ) .map_err(|_| hipError_t::hipErrorUnknown)?; - let err = unsafe { hipModuleLoadData(pmod as _, arch_binary.as_ptr() as _) }; - if err != hipError_t::hipSuccess { - return Err(err); - } + hip_call! { hipModuleLoadData(pmod as _, arch_binary.as_ptr() as _) }; Ok(()) } @@ -172,7 +176,6 @@ pub(crate) fn compile_amd<'a>( llvm_link.push("llvm-link"); let mut linker_cmd = Command::new(&llvm_link); linker_cmd - .arg("--only-needed") .arg("-o") .arg(linked_binary.path()) .args(llvm_files.iter().map(|f| f.path())) |