diff options
Diffstat (limited to 'zluda/src/impl/module.rs')
-rw-r--r-- | zluda/src/impl/module.rs | 30 |
1 files changed, 18 insertions, 12 deletions
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 6575d96..ba09869 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -87,7 +87,7 @@ pub fn load_data_impl(pmod: *mut CUmodule, spirv_data: SpirvModule) -> Result<() let err = unsafe { hipGetDeviceProperties(&mut props, dev) }; let arch_binary = compile_amd( &props, - &spirv_data.binaries[..], + &[&spirv_data.binaries[..]], spirv_data.should_link_ptx_impl, ) .map_err(|_| hipError_t::hipErrorUnknown)?; @@ -115,7 +115,7 @@ const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_"; fn compile_amd( device_pros: &hipDeviceProp_t, - spirv_il: &[u32], + spirv_il: &[&[u32]], ptx_lib: Option<(&'static [u8], &'static [u8])>, ) -> io::Result<Vec<u8>> { let null_terminator = device_pros @@ -133,24 +133,30 @@ fn compile_amd( return Err(io::Error::new(io::ErrorKind::Other, "")); }; let dir = tempfile::tempdir()?; - let mut spirv = NamedTempFile::new_in(&dir)?; - let llvm = NamedTempFile::new_in(&dir)?; - let spirv_il_u8 = unsafe { - slice::from_raw_parts( - spirv_il.as_ptr() as *const u8, - spirv_il.len() * mem::size_of::<u32>(), - ) - }; - spirv.write_all(spirv_il_u8)?; + let spirv_files = spirv_il + .iter() + .map(|spirv| { + let mut spirv = NamedTempFile::new_in(&dir)?; + let spirv_il_u8 = unsafe { + slice::from_raw_parts( + spirv_il.as_ptr() as *const u8, + spirv_il.len() * mem::size_of::<u32>(), + ) + }; + spirv.write_all(spirv_il_u8)?; + Ok::<_, io::Error>(spirv) + }) + .collect::<Result<Vec<_>, _>>()?; let llvm_spirv_path = match env::var("LLVM_SPIRV") { Ok(path) => Cow::Owned(path), Err(_) => Cow::Borrowed(LLVM_SPIRV), }; + let llvm = NamedTempFile::new_in(&dir)?; let to_llvm_cmd = Command::new(&*llvm_spirv_path) .arg("-r") .arg("-o") .arg(llvm.path()) - .arg(spirv.path()) + .args(spirv_files.iter().map(|f| f.path())) .status()?; assert!(to_llvm_cmd.success()); if cfg!(debug_assertions) { |