about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author Andrzej Janik <[email protected]> 2021-08-06 13:19:55 +0200
committer Andrzej Janik <[email protected]> 2021-08-06 13:19:55 +0200
commit 479014a783488a225ce320e784d6e9cdf18190ba (patch)
tree 8b53e893be809971b1ff2752d7c63c5a52874302
parent 5bfc2a56b92ffdbd80718aa7959f97578c579d8d (diff)
download ZLUDA-479014a783488a225ce320e784d6e9cdf18190ba.tar.gz
ZLUDA-479014a783488a225ce320e784d6e9cdf18190ba.zip
Wire up AMD compilation
-rw-r--r-- zluda/src/impl/device.rs 19
-rw-r--r-- zluda/src/impl/module.rs 114
2 files changed, 84 insertions, 49 deletions
diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs
index 7e65272..5fdb24b 100644
--- a/zluda/src/impl/device.rs
+++ b/zluda/src/impl/device.rs
@@ -27,6 +27,7 @@ pub struct Device {
pub primary_context: context::Context,
pub allocations: HashSet<*mut c_void>,
pub is_amd: bool,
+ pub name: String,
}
unsafe impl Send for Device {}
@@ -44,6 +45,12 @@ impl Device {
let queue = ocl_core::create_command_queue(&ctx, ocl_dev, None)?;
let primary_context =
context::Context::new(context::ContextData::new(0, true, ptr::null_mut())?);
+ let props = ocl_core::get_device_info(ocl_dev, ocl_core::DeviceInfo::Name)?;
+ let name = if let ocl_core::DeviceInfoResult::Name(name) = props {
+ Ok(name)
+ } else {
+ Err(CUresult::CUDA_ERROR_UNKNOWN)
+ }?;
Ok(Self {
index: Index(idx as c_int),
ocl_base: ocl_dev,
@@ -52,6 +59,7 @@ impl Device {
primary_context,
allocations: HashSet::new(),
is_amd,
+ name,
})
}
@@ -83,14 +91,7 @@ pub fn get_name(name: *mut c_char, len: i32, dev_idx: Index) -> Result<(), CUres
if name == ptr::null_mut() || len < 0 {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- let name_string = GlobalState::lock_device(dev_idx, |dev| {
- let props = ocl_core::get_device_info(dev.ocl_base, ocl_core::DeviceInfo::Name)?;
- if let ocl_core::DeviceInfoResult::Name(name) = props {
- Ok(name)
- } else {
- Err(CUresult::CUDA_ERROR_UNKNOWN)
- }
- })??;
+ let name_string = GlobalState::lock_device(dev_idx, |dev| dev.name.clone())?;
let mut dst_null_pos = cmp::min((len - 1) as usize, name_string.len());
unsafe { std::ptr::copy_nonoverlapping(name_string.as_ptr() as *const _, name, dst_null_pos) };
if name_string.len() + PROJECT_URL_SUFFIX_LONG.len() < (len as usize) {
@@ -179,7 +180,7 @@ pub fn get_attribute(
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR => {
GlobalState::lock_device(dev_idx, |dev| {
if !dev.is_amd {
- 8i32 * 7 // correct for GEN9
+ 7 // correct for GEN9
} else {
4i32 * 32 // probably correct for RDNA
}
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs
index f86e563..f2a453e 100644
--- a/zluda/src/impl/module.rs
+++ b/zluda/src/impl/module.rs
@@ -1,10 +1,11 @@
use std::{
+ borrow::Cow,
collections::hash_map,
collections::HashMap,
ffi::c_void,
ffi::CStr,
ffi::CString,
- io::{self, Write},
+ io::{self, Read, Write},
mem,
os::raw::{c_char, c_int, c_uint},
path::PathBuf,
@@ -106,9 +107,8 @@ impl SpirvModule {
"oclc_wavefrontsize64_off.bc",
];
const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_";
- const AMDGPU_DEVICE: &'static str = "gfx1010";
- fn get_bitcode_paths() -> impl Iterator<Item = PathBuf> {
+ fn get_bitcode_paths(device_name: &str) -> impl Iterator<Item = PathBuf> {
let generic_paths = Self::AMDGPU_BITCODE.iter().map(|x| {
let mut path = PathBuf::from(Self::AMDGPU);
path.push("amdgcn");
@@ -122,19 +122,27 @@ impl SpirvModule {
additional_path.push(format!(
"{}{}{}",
Self::AMDGPU_BITCODE_DEVICE_PREFIX,
- &Self::AMDGPU_DEVICE[3..],
+ &device_name[3..],
".bc"
));
generic_paths.chain(std::iter::once(additional_path))
}
#[cfg(not(target_os = "linux"))]
- fn compile_amd(spirv_il: &[u8], ptx_lib: Option<&'static [u8]>) -> io::Result<()> {
- Ok(())
+ fn compile_amd(
+ device_name: &str,
+ spirv_il: &[u8],
+ ptx_lib: Option<(&'static [u8], &'static [u8])>,
+ ) -> io::Result<Vec<u8>> {
+ unimplemented!()
}
#[cfg(target_os = "linux")]
- fn compile_amd(spirv_il: &[u8], ptx_lib: Option<&'static [u8]>) -> io::Result<()> {
+ fn compile_amd(
+ device_name: &str,
+ spirv_il: &[u8],
+ ptx_lib: Option<(&'static [u8], &'static [u8])>,
+ ) -> io::Result<Vec<u8>> {
let dir = tempfile::tempdir()?;
let mut spirv = NamedTempFile::new_in(&dir)?;
let llvm = NamedTempFile::new_in(&dir)?;
@@ -156,20 +164,20 @@ impl SpirvModule {
.arg("-o")
.arg(linked_binary.path())
.arg(llvm.path())
- .args(Self::get_bitcode_paths());
+ .args(Self::get_bitcode_paths(device_name));
if cfg!(debug_assertions) {
linker_cmd.arg("-v");
}
let status = linker_cmd.status()?;
assert!(status.success());
let mut ptx_lib_bitcode = NamedTempFile::new_in(&dir)?;
- let compiled_binary = NamedTempFile::new_in(&dir)?;
+ let mut compiled_binary = NamedTempFile::new_in(&dir)?;
let mut cland_exe = PathBuf::from(Self::AMDGPU);
cland_exe.push("bin");
cland_exe.push("clang");
let mut compiler_cmd = Command::new(&cland_exe);
compiler_cmd
- .arg(format!("-mcpu={}", Self::AMDGPU_DEVICE))
+ .arg(format!("-mcpu={}", device_name))
.arg("-O3")
.arg("-Xlinker")
.arg("--no-undefined")
@@ -178,7 +186,7 @@ impl SpirvModule {
.arg("-o")
.arg(compiled_binary.path())
.arg(linked_binary.path());
- if let Some(bitcode) = ptx_lib {
+ if let Some((_, bitcode)) = ptx_lib {
ptx_lib_bitcode.write_all(bitcode)?;
compiler_cmd.arg(ptx_lib_bitcode.path());
};
@@ -187,40 +195,30 @@ impl SpirvModule {
}
let status = compiler_cmd.status()?;
assert!(status.success());
- Ok(())
+ let mut result = Vec::new();
+ compiled_binary.read_to_end(&mut result)?;
+ Ok(result)
}
- pub fn compile<'a>(
- &self,
+ fn compile_intel<'a>(
ctx: &ocl_core::Context,
dev: &ocl_core::DeviceId,
- ) -> Result<ocl_core::Program, CUresult> {
- let byte_il = unsafe {
- slice::from_raw_parts(
- self.binaries.as_ptr() as *const u8,
- self.binaries.len() * mem::size_of::<u32>(),
- )
- };
+ byte_il: &'a [u8],
+ build_options: &CString,
+ ptx_lib: Option<(&'static [u8], &'static [u8])>,
+ ) -> ocl_core::Result<ocl_core::Program> {
let main_module = ocl_core::create_program_with_il(ctx, byte_il, None)?;
- let main_module = match self.should_link_ptx_impl {
+ Ok(match ptx_lib {
None => {
- Self::compile_amd(byte_il, None).unwrap();
- ocl_core::build_program(
- &main_module,
- Some(&[dev]),
- &self.build_options,
- None,
- None,
- )?;
+ ocl_core::build_program(&main_module, Some(&[dev]), build_options, None, None)?;
main_module
}
- Some((ptx_impl_intel, ptx_impl_amd)) => {
- Self::compile_amd(byte_il, Some(ptx_impl_amd)).unwrap();
+ Some((ptx_impl_intel, _)) => {
let ptx_impl_prog = ocl_core::create_program_with_il(ctx, ptx_impl_intel, None)?;
ocl_core::compile_program(
&main_module,
Some(&[dev]),
- &self.build_options,
+ build_options,
&[],
&[],
None,
@@ -230,7 +228,7 @@ impl SpirvModule {
ocl_core::compile_program(
&ptx_impl_prog,
Some(&[dev]),
- &self.build_options,
+ build_options,
&[],
&[],
None,
@@ -240,15 +238,43 @@ impl SpirvModule {
ocl_core::link_program(
ctx,
Some(&[dev]),
- &self.build_options,
+ build_options,
&[&main_module, &ptx_impl_prog],
None,
None,
None,
)?
}
+ })
+ }
+
+ pub fn compile<'a>(
+ &self,
+ ctx: &ocl_core::Context,
+ dev: &ocl_core::DeviceId,
+ device_name: &str,
+ is_amd: bool,
+ ) -> Result<ocl_core::Program, CUresult> {
+ let byte_il = unsafe {
+ slice::from_raw_parts(
+ self.binaries.as_ptr() as *const u8,
+ self.binaries.len() * mem::size_of::<u32>(),
+ )
+ };
+ let ocl_program = if is_amd {
+ let binary_prog =
+ Self::compile_amd(device_name, byte_il, self.should_link_ptx_impl).unwrap();
+ ocl_core::create_program_with_binary(ctx, &[dev], &[&binary_prog[..]])?
+ } else {
+ Self::compile_intel(
+ ctx,
+ dev,
+ byte_il,
+ &self.build_options,
+ self.should_link_ptx_impl,
+ )?
};
- Ok(main_module)
+ Ok(ocl_program)
}
}
@@ -268,9 +294,12 @@ pub fn get_function(
hash_map::Entry::Occupied(entry) => entry.into_mut(),
hash_map::Entry::Vacant(entry) => {
let new_module = CompiledModule {
- base: module
- .spirv
- .compile(&device.ocl_context, &device.ocl_base)?,
+ base: module.spirv.compile(
+ &device.ocl_context,
+ &device.ocl_base,
+ &device.name,
+ device.is_amd,
+ )?,
kernels: HashMap::new(),
};
entry.insert(new_module)
@@ -340,7 +369,12 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
let module = GlobalState::lock_current_context(|ctx| {
let device = unsafe { &mut *ctx.device };
- let l0_module = spirv_data.compile(&device.ocl_context, &device.ocl_base)?;
+ let l0_module = spirv_data.compile(
+ &device.ocl_context,
+ &device.ocl_base,
+ &device.name,
+ device.is_amd,
+ )?;
let mut device_binaries = HashMap::new();
let compiled_module = CompiledModule {
base: l0_module,