diff options
-rw-r--r-- | zluda/src/cuda.rs | 5 | ||||
-rw-r--r-- | zluda/src/impl/link.rs | 44 | ||||
-rw-r--r-- | zluda/src/impl/module.rs | 9 |
3 files changed, 36 insertions, 22 deletions
diff --git a/zluda/src/cuda.rs b/zluda/src/cuda.rs index bceb8bc..e7f5e42 100644 --- a/zluda/src/cuda.rs +++ b/zluda/src/cuda.rs @@ -2565,6 +2565,7 @@ pub unsafe extern "system" fn cuLinkAddData_v2( options, optionValues, ) + .encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2580,12 +2581,12 @@ pub extern "system" fn cuLinkAddFile_v2( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuLinkComplete( +pub unsafe extern "system" fn cuLinkComplete( state: CUlinkState, cubinOut: *mut *mut ::std::os::raw::c_void, sizeOut: *mut usize, ) -> CUresult { - r#impl::link::complete(state, cubinOut, sizeOut) + r#impl::link::complete(state, cubinOut, sizeOut).encuda() } #[cfg_attr(not(test), no_mangle)] diff --git a/zluda/src/impl/link.rs b/zluda/src/impl/link.rs index 928180d..35b156a 100644 --- a/zluda/src/impl/link.rs +++ b/zluda/src/impl/link.rs @@ -3,10 +3,18 @@ use std::{ mem, ptr, slice, }; -use crate::cuda::{CUjitInputType, CUjit_option, CUlinkState, CUresult}; +use hip_runtime_sys::{hipCtxGetDevice, hipError_t, hipGetDeviceProperties}; + +use crate::{ + cuda::{CUjitInputType, CUjit_option, CUlinkState, CUresult}, + hip_call, +}; + +use super::module::{self, SpirvModule}; struct LinkState { - modules: Vec<String>, + modules: Vec<SpirvModule>, + result: Option<Vec<u8>>, } pub(crate) unsafe fn create( @@ -20,6 +28,7 @@ pub(crate) unsafe fn create( } let state = Box::new(LinkState { modules: Vec::new(), + result: None, }); *state_out = mem::transmute(state); CUresult::CUDA_SUCCESS @@ -34,31 +43,36 @@ pub(crate) unsafe fn add_data( num_options: u32, options: *mut CUjit_option, option_values: *mut *mut c_void, -) -> CUresult { +) -> Result<(), hipError_t> { if state == ptr::null_mut() { - return CUresult::CUDA_ERROR_INVALID_VALUE; + return Err(hipError_t::hipErrorInvalidValue); } let state: *mut LinkState = mem::transmute(state); let state = &mut *state; // V-RAY specific hack if state.modules.len() == 2 { - return CUresult::CUDA_SUCCESS; + return Err(hipError_t::hipSuccess); } - let ptx = slice::from_raw_parts(data as *mut u8, size); - state.modules.push( - CStr::from_bytes_with_nul_unchecked(ptx) - .to_string_lossy() - .to_string(), - ); - CUresult::CUDA_SUCCESS + let spirv_data = SpirvModule::new_raw(data as *const _)?; + state.modules.push(spirv_data); + Ok(()) } -pub(crate) fn complete( +pub(crate) unsafe fn complete( state: CUlinkState, cubin_out: *mut *mut c_void, size_out: *mut usize, -) -> CUresult { - CUresult::CUDA_SUCCESS +) -> Result<(), hipError_t> { + let mut dev = 0; + hip_call! { hipCtxGetDevice(&mut dev) }; + let mut props = unsafe { mem::zeroed() }; + hip_call! { hipGetDeviceProperties(&mut props, dev) }; + let state: &LinkState = mem::transmute(state); + let spirv_bins = state.modules.iter().map(|m| &m.binaries[..]); + let should_link_ptx_impl = state.modules.iter().find_map(|m| m.should_link_ptx_impl); + let arch_binary = module::compile_amd(&props, spirv_bins, should_link_ptx_impl) + .map_err(|_| hipError_t::hipErrorUnknown)?; + Ok(()) } pub(crate) unsafe fn destroy(state: CUlinkState) -> CUresult { diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 1cf0a5a..5560526 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -7,7 +7,7 @@ use std::ops::Add; use std::os::raw::c_char; use std::path::{Path, PathBuf}; use std::process::Command; -use std::{env, fs, mem, ptr, slice}; +use std::{env, fs, iter, mem, ptr, slice}; use hip_runtime_sys::{ hipCtxGetCurrent, hipCtxGetDevice, hipDeviceGetAttribute, hipDeviceGetName, hipDeviceProp_t, @@ -87,7 +87,7 @@ pub fn load_data_impl(pmod: *mut CUmodule, spirv_data: SpirvModule) -> Result<() let err = unsafe { hipGetDeviceProperties(&mut props, dev) }; let arch_binary = compile_amd( &props, - &[&spirv_data.binaries[..]], + iter::once(&spirv_data.binaries[..]), spirv_data.should_link_ptx_impl, ) .map_err(|_| hipError_t::hipErrorUnknown)?; @@ -113,9 +113,9 @@ const AMDGPU_BITCODE: [&'static str; 8] = [ ]; const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_"; -fn compile_amd( +pub(crate) fn compile_amd<'a>( device_pros: &hipDeviceProp_t, - spirv_il: &[&[u32]], + spirv_il: impl Iterator<Item = &'a [u32]>, ptx_lib: Option<(&'static [u8], &'static [u8])>, ) -> io::Result<Vec<u8>> { let null_terminator = device_pros @@ -134,7 +134,6 @@ fn compile_amd( }; let dir = tempfile::tempdir()?; let spirv_files = spirv_il - .iter() .map(|spirv| { let mut spirv_file = NamedTempFile::new_in(&dir)?; let spirv_u8 = unsafe { |