aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src/impl/module.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda/src/impl/module.rs')
-rw-r--r--zluda/src/impl/module.rs30
1 files changed, 18 insertions, 12 deletions
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs
index 6575d96..ba09869 100644
--- a/zluda/src/impl/module.rs
+++ b/zluda/src/impl/module.rs
@@ -87,7 +87,7 @@ pub fn load_data_impl(pmod: *mut CUmodule, spirv_data: SpirvModule) -> Result<()
let err = unsafe { hipGetDeviceProperties(&mut props, dev) };
let arch_binary = compile_amd(
&props,
- &spirv_data.binaries[..],
+ &[&spirv_data.binaries[..]],
spirv_data.should_link_ptx_impl,
)
.map_err(|_| hipError_t::hipErrorUnknown)?;
@@ -115,7 +115,7 @@ const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_";
fn compile_amd(
device_pros: &hipDeviceProp_t,
- spirv_il: &[u32],
+ spirv_il: &[&[u32]],
ptx_lib: Option<(&'static [u8], &'static [u8])>,
) -> io::Result<Vec<u8>> {
let null_terminator = device_pros
@@ -133,24 +133,30 @@ fn compile_amd(
return Err(io::Error::new(io::ErrorKind::Other, ""));
};
let dir = tempfile::tempdir()?;
- let mut spirv = NamedTempFile::new_in(&dir)?;
- let llvm = NamedTempFile::new_in(&dir)?;
- let spirv_il_u8 = unsafe {
- slice::from_raw_parts(
- spirv_il.as_ptr() as *const u8,
- spirv_il.len() * mem::size_of::<u32>(),
- )
- };
- spirv.write_all(spirv_il_u8)?;
+ let spirv_files = spirv_il
+ .iter()
+ .map(|spirv| {
+ let mut spirv = NamedTempFile::new_in(&dir)?;
+ let spirv_il_u8 = unsafe {
+ slice::from_raw_parts(
+ spirv_il.as_ptr() as *const u8,
+ spirv_il.len() * mem::size_of::<u32>(),
+ )
+ };
+ spirv.write_all(spirv_il_u8)?;
+ Ok::<_, io::Error>(spirv)
+ })
+ .collect::<Result<Vec<_>, _>>()?;
let llvm_spirv_path = match env::var("LLVM_SPIRV") {
Ok(path) => Cow::Owned(path),
Err(_) => Cow::Borrowed(LLVM_SPIRV),
};
+ let llvm = NamedTempFile::new_in(&dir)?;
let to_llvm_cmd = Command::new(&*llvm_spirv_path)
.arg("-r")
.arg("-o")
.arg(llvm.path())
- .arg(spirv.path())
+ .args(spirv_files.iter().map(|f| f.path()))
.status()?;
assert!(to_llvm_cmd.success());
if cfg!(debug_assertions) {