aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src/impl/link.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda/src/impl/link.rs')
-rw-r--r--zluda/src/impl/link.rs44
1 files changed, 29 insertions, 15 deletions
diff --git a/zluda/src/impl/link.rs b/zluda/src/impl/link.rs
index 928180d..35b156a 100644
--- a/zluda/src/impl/link.rs
+++ b/zluda/src/impl/link.rs
@@ -3,10 +3,18 @@ use std::{
mem, ptr, slice,
};
-use crate::cuda::{CUjitInputType, CUjit_option, CUlinkState, CUresult};
+use hip_runtime_sys::{hipCtxGetDevice, hipError_t, hipGetDeviceProperties};
+
+use crate::{
+ cuda::{CUjitInputType, CUjit_option, CUlinkState, CUresult},
+ hip_call,
+};
+
+use super::module::{self, SpirvModule};
struct LinkState {
- modules: Vec<String>,
+ modules: Vec<SpirvModule>,
+ result: Option<Vec<u8>>,
}
pub(crate) unsafe fn create(
@@ -20,6 +28,7 @@ pub(crate) unsafe fn create(
}
let state = Box::new(LinkState {
modules: Vec::new(),
+ result: None,
});
*state_out = mem::transmute(state);
CUresult::CUDA_SUCCESS
@@ -34,31 +43,36 @@ pub(crate) unsafe fn add_data(
num_options: u32,
options: *mut CUjit_option,
option_values: *mut *mut c_void,
-) -> CUresult {
+) -> Result<(), hipError_t> {
if state == ptr::null_mut() {
- return CUresult::CUDA_ERROR_INVALID_VALUE;
+ return Err(hipError_t::hipErrorInvalidValue);
}
let state: *mut LinkState = mem::transmute(state);
let state = &mut *state;
// V-RAY specific hack
if state.modules.len() == 2 {
- return CUresult::CUDA_SUCCESS;
+ return Err(hipError_t::hipSuccess);
}
- let ptx = slice::from_raw_parts(data as *mut u8, size);
- state.modules.push(
- CStr::from_bytes_with_nul_unchecked(ptx)
- .to_string_lossy()
- .to_string(),
- );
- CUresult::CUDA_SUCCESS
+ let spirv_data = SpirvModule::new_raw(data as *const _)?;
+ state.modules.push(spirv_data);
+ Ok(())
}
-pub(crate) fn complete(
+pub(crate) unsafe fn complete(
state: CUlinkState,
cubin_out: *mut *mut c_void,
size_out: *mut usize,
-) -> CUresult {
- CUresult::CUDA_SUCCESS
+) -> Result<(), hipError_t> {
+ let mut dev = 0;
+ hip_call! { hipCtxGetDevice(&mut dev) };
+ let mut props = unsafe { mem::zeroed() };
+ hip_call! { hipGetDeviceProperties(&mut props, dev) };
+ let state: &LinkState = mem::transmute(state);
+ let spirv_bins = state.modules.iter().map(|m| &m.binaries[..]);
+ let should_link_ptx_impl = state.modules.iter().find_map(|m| m.should_link_ptx_impl);
+ let arch_binary = module::compile_amd(&props, spirv_bins, should_link_ptx_impl)
+ .map_err(|_| hipError_t::hipErrorUnknown)?;
+ Ok(())
}
pub(crate) unsafe fn destroy(state: CUlinkState) -> CUresult {