diff options
-rw-r--r-- | zluda/src/impl/device.rs | 19 | ||||
-rw-r--r-- | zluda/src/impl/module.rs | 114 |
2 files changed, 84 insertions, 49 deletions
diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 7e65272..5fdb24b 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -27,6 +27,7 @@ pub struct Device { pub primary_context: context::Context, pub allocations: HashSet<*mut c_void>, pub is_amd: bool, + pub name: String, } unsafe impl Send for Device {} @@ -44,6 +45,12 @@ impl Device { let queue = ocl_core::create_command_queue(&ctx, ocl_dev, None)?; let primary_context = context::Context::new(context::ContextData::new(0, true, ptr::null_mut())?); + let props = ocl_core::get_device_info(ocl_dev, ocl_core::DeviceInfo::Name)?; + let name = if let ocl_core::DeviceInfoResult::Name(name) = props { + Ok(name) + } else { + Err(CUresult::CUDA_ERROR_UNKNOWN) + }?; Ok(Self { index: Index(idx as c_int), ocl_base: ocl_dev, @@ -52,6 +59,7 @@ impl Device { primary_context, allocations: HashSet::new(), is_amd, + name, }) } @@ -83,14 +91,7 @@ pub fn get_name(name: *mut c_char, len: i32, dev_idx: Index) -> Result<(), CUres if name == ptr::null_mut() || len < 0 { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - let name_string = GlobalState::lock_device(dev_idx, |dev| { - let props = ocl_core::get_device_info(dev.ocl_base, ocl_core::DeviceInfo::Name)?; - if let ocl_core::DeviceInfoResult::Name(name) = props { - Ok(name) - } else { - Err(CUresult::CUDA_ERROR_UNKNOWN) - } - })??; + let name_string = GlobalState::lock_device(dev_idx, |dev| dev.name.clone())?; let mut dst_null_pos = cmp::min((len - 1) as usize, name_string.len()); unsafe { std::ptr::copy_nonoverlapping(name_string.as_ptr() as *const _, name, dst_null_pos) }; if name_string.len() + PROJECT_URL_SUFFIX_LONG.len() < (len as usize) { @@ -179,7 +180,7 @@ pub fn get_attribute( CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR => { GlobalState::lock_device(dev_idx, |dev| { if !dev.is_amd { - 8i32 * 7 // correct for GEN9 + 7 // correct for GEN9 } else { 4i32 * 32 // probably correct for RDNA } diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index f86e563..f2a453e 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -1,10 +1,11 @@ use std::{ + borrow::Cow, collections::hash_map, collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, - io::{self, Write}, + io::{self, Read, Write}, mem, os::raw::{c_char, c_int, c_uint}, path::PathBuf, @@ -106,9 +107,8 @@ impl SpirvModule { "oclc_wavefrontsize64_off.bc", ]; const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_"; - const AMDGPU_DEVICE: &'static str = "gfx1010"; - fn get_bitcode_paths() -> impl Iterator<Item = PathBuf> { + fn get_bitcode_paths(device_name: &str) -> impl Iterator<Item = PathBuf> { let generic_paths = Self::AMDGPU_BITCODE.iter().map(|x| { let mut path = PathBuf::from(Self::AMDGPU); path.push("amdgcn"); @@ -122,19 +122,27 @@ impl SpirvModule { additional_path.push(format!( "{}{}{}", Self::AMDGPU_BITCODE_DEVICE_PREFIX, - &Self::AMDGPU_DEVICE[3..], + &device_name[3..], ".bc" )); generic_paths.chain(std::iter::once(additional_path)) } #[cfg(not(target_os = "linux"))] - fn compile_amd(spirv_il: &[u8], ptx_lib: Option<&'static [u8]>) -> io::Result<()> { - Ok(()) + fn compile_amd( + device_name: &str, + spirv_il: &[u8], + ptx_lib: Option<(&'static [u8], &'static [u8])>, + ) -> io::Result<Vec<u8>> { + unimplemented!() } #[cfg(target_os = "linux")] - fn compile_amd(spirv_il: &[u8], ptx_lib: Option<&'static [u8]>) -> io::Result<()> { + fn compile_amd( + device_name: &str, + spirv_il: &[u8], + ptx_lib: Option<(&'static [u8], &'static [u8])>, + ) -> io::Result<Vec<u8>> { let dir = tempfile::tempdir()?; let mut spirv = NamedTempFile::new_in(&dir)?; let llvm = NamedTempFile::new_in(&dir)?; @@ -156,20 +164,20 @@ impl SpirvModule { .arg("-o") .arg(linked_binary.path()) .arg(llvm.path()) - .args(Self::get_bitcode_paths()); + .args(Self::get_bitcode_paths(device_name)); if cfg!(debug_assertions) { linker_cmd.arg("-v"); } let status = linker_cmd.status()?; assert!(status.success()); let mut ptx_lib_bitcode = NamedTempFile::new_in(&dir)?; - let compiled_binary = NamedTempFile::new_in(&dir)?; + let mut compiled_binary = NamedTempFile::new_in(&dir)?; let mut cland_exe = PathBuf::from(Self::AMDGPU); cland_exe.push("bin"); cland_exe.push("clang"); let mut compiler_cmd = Command::new(&cland_exe); compiler_cmd - .arg(format!("-mcpu={}", Self::AMDGPU_DEVICE)) + .arg(format!("-mcpu={}", device_name)) .arg("-O3") .arg("-Xlinker") .arg("--no-undefined") @@ -178,7 +186,7 @@ impl SpirvModule { .arg("-o") .arg(compiled_binary.path()) .arg(linked_binary.path()); - if let Some(bitcode) = ptx_lib { + if let Some((_, bitcode)) = ptx_lib { ptx_lib_bitcode.write_all(bitcode)?; compiler_cmd.arg(ptx_lib_bitcode.path()); }; @@ -187,40 +195,30 @@ impl SpirvModule { } let status = compiler_cmd.status()?; assert!(status.success()); - Ok(()) + let mut result = Vec::new(); + compiled_binary.read_to_end(&mut result)?; + Ok(result) } - pub fn compile<'a>( - &self, + fn compile_intel<'a>( ctx: &ocl_core::Context, dev: &ocl_core::DeviceId, - ) -> Result<ocl_core::Program, CUresult> { - let byte_il = unsafe { - slice::from_raw_parts( - self.binaries.as_ptr() as *const u8, - self.binaries.len() * mem::size_of::<u32>(), - ) - }; + byte_il: &'a [u8], + build_options: &CString, + ptx_lib: Option<(&'static [u8], &'static [u8])>, + ) -> ocl_core::Result<ocl_core::Program> { let main_module = ocl_core::create_program_with_il(ctx, byte_il, None)?; - let main_module = match self.should_link_ptx_impl { + Ok(match ptx_lib { None => { - Self::compile_amd(byte_il, None).unwrap(); - ocl_core::build_program( - &main_module, - Some(&[dev]), - &self.build_options, - None, - None, - )?; + ocl_core::build_program(&main_module, Some(&[dev]), build_options, None, None)?; main_module } - Some((ptx_impl_intel, ptx_impl_amd)) => { - Self::compile_amd(byte_il, Some(ptx_impl_amd)).unwrap(); + Some((ptx_impl_intel, _)) => { let ptx_impl_prog = ocl_core::create_program_with_il(ctx, ptx_impl_intel, None)?; ocl_core::compile_program( &main_module, Some(&[dev]), - &self.build_options, + build_options, &[], &[], None, @@ -230,7 +228,7 @@ impl SpirvModule { ocl_core::compile_program( &ptx_impl_prog, Some(&[dev]), - &self.build_options, + build_options, &[], &[], None, @@ -240,15 +238,43 @@ impl SpirvModule { ocl_core::link_program( ctx, Some(&[dev]), - &self.build_options, + build_options, &[&main_module, &ptx_impl_prog], None, None, None, )? } + }) + } + + pub fn compile<'a>( + &self, + ctx: &ocl_core::Context, + dev: &ocl_core::DeviceId, + device_name: &str, + is_amd: bool, + ) -> Result<ocl_core::Program, CUresult> { + let byte_il = unsafe { + slice::from_raw_parts( + self.binaries.as_ptr() as *const u8, + self.binaries.len() * mem::size_of::<u32>(), + ) + }; + let ocl_program = if is_amd { + let binary_prog = + Self::compile_amd(device_name, byte_il, self.should_link_ptx_impl).unwrap(); + ocl_core::create_program_with_binary(ctx, &[dev], &[&binary_prog[..]])? + } else { + Self::compile_intel( + ctx, + dev, + byte_il, + &self.build_options, + self.should_link_ptx_impl, + )? }; - Ok(main_module) + Ok(ocl_program) } } @@ -268,9 +294,12 @@ pub fn get_function( hash_map::Entry::Occupied(entry) => entry.into_mut(), hash_map::Entry::Vacant(entry) => { let new_module = CompiledModule { - base: module - .spirv - .compile(&device.ocl_context, &device.ocl_base)?, + base: module.spirv.compile( + &device.ocl_context, + &device.ocl_base, + &device.name, + device.is_amd, + )?, kernels: HashMap::new(), }; entry.insert(new_module) @@ -340,7 +369,12 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result< pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> { let module = GlobalState::lock_current_context(|ctx| { let device = unsafe { &mut *ctx.device }; - let l0_module = spirv_data.compile(&device.ocl_context, &device.ocl_base)?; + let l0_module = spirv_data.compile( + &device.ocl_context, + &device.ocl_base, + &device.name, + device.is_amd, + )?; let mut device_binaries = HashMap::new(); let compiled_module = CompiledModule { base: l0_module, |