about | summary | refs | log | tree | commit | diff | homepage
path: root/zluda
diff options
context:
space:
mode:
Diffstat (limited to 'zluda')
-rw-r--r-- zluda/src/cuda.rs 46
-rw-r--r-- zluda/src/impl/export_table.rs 7
-rw-r--r-- zluda/src/impl/function.rs 12
-rw-r--r-- zluda/src/impl/mod.rs 10
-rw-r--r-- zluda/src/impl/module.rs 232
5 files changed, 284 insertions, 23 deletions
diff --git a/zluda/src/cuda.rs b/zluda/src/cuda.rs
index 634c0df..68e5a80 100644
--- a/zluda/src/cuda.rs
+++ b/zluda/src/cuda.rs
@@ -2454,7 +2454,7 @@ pub extern "system" fn cuModuleLoad(
module: *mut CUmodule,
fname: *const ::std::os::raw::c_char,
) -> CUresult {
- unsafe { hipModuleLoad(module as _, fname as _).into() }
+ r#impl::module::load(module, fname).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -2462,7 +2462,7 @@ pub extern "system" fn cuModuleLoadData(
module: *mut CUmodule,
image: *const ::std::os::raw::c_void,
) -> CUresult {
- unsafe { hipModuleLoadData(module as _, image as _).into() }
+ r#impl::module::load_data(module, image).encuda()
}
// TODO: parse jit options
@@ -2474,16 +2474,7 @@ pub extern "system" fn cuModuleLoadDataEx(
options: *mut CUjit_option,
optionValues: *mut *mut ::std::os::raw::c_void,
) -> CUresult {
- unsafe {
- hipModuleLoadDataEx(
- module as _,
- image as _,
- numOptions,
- options as _,
- optionValues,
- )
- .into()
- }
+ r#impl::module::load_data(module, image).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -3710,7 +3701,22 @@ pub extern "system" fn cuLaunchKernel(
kernelParams: *mut *mut ::std::os::raw::c_void,
extra: *mut *mut ::std::os::raw::c_void,
) -> CUresult {
- todo!()
+ unsafe {
+ hipModuleLaunchKernel(
+ f as _,
+ gridDimX,
+ gridDimY,
+ gridDimZ,
+ blockDimX,
+ blockDimY,
+ blockDimZ,
+ sharedMemBytes,
+ hStream as _,
+ kernelParams,
+ extra,
+ )
+ }
+ .into()
}
// TODO: implement default stream semantics
@@ -3728,7 +3734,19 @@ pub extern "system" fn cuLaunchKernel_ptsz(
kernelParams: *mut *mut ::std::os::raw::c_void,
extra: *mut *mut ::std::os::raw::c_void,
) -> CUresult {
- todo!()
+ cuLaunchKernel(
+ f,
+ gridDimX,
+ gridDimY,
+ gridDimZ,
+ blockDimX,
+ blockDimY,
+ blockDimZ,
+ sharedMemBytes,
+ hStream,
+ kernelParams,
+ extra,
+ )
}
#[cfg_attr(not(test), no_mangle)]
diff --git a/zluda/src/impl/export_table.rs b/zluda/src/impl/export_table.rs
index 5734f05..c95588c 100644
--- a/zluda/src/impl/export_table.rs
+++ b/zluda/src/impl/export_table.rs
@@ -12,7 +12,7 @@ use crate::{
cuda_impl,
};
-use super::{device, Decuda, Encuda};
+use super::{device, module, Decuda, Encuda};
use std::collections::HashMap;
use std::os::raw::{c_uint, c_ulong, c_ushort};
use std::{
@@ -253,20 +253,17 @@ unsafe extern "system" fn get_module_from_cubin(
},
Err(_) => continue,
};
- todo!()
- /*
let module = module::SpirvModule::new(kernel_text_string);
match module {
Ok(module) => {
match module::load_data_impl(result, module) {
Ok(()) => {}
- Err(err) => return err,
+ Err(err) => return err.into(),
}
return CUresult::CUDA_SUCCESS;
}
Err(_) => continue,
}
- */
}
CUresult::CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE
}
diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs
index 8470620..c5ea964 100644
--- a/zluda/src/impl/function.rs
+++ b/zluda/src/impl/function.rs
@@ -1,7 +1,7 @@
-use hip_runtime_sys::{hipError_t, hipFuncGetAttributes};
+use hip_runtime_sys::{hipError_t, hipFuncGetAttributes, hipLaunchKernel, hipModuleLaunchKernel};
use super::{CUresult, HasLivenessCookie, LiveCheck};
-use crate::cuda::{CUfunction, CUfunction_attribute};
+use crate::cuda::{CUfunction, CUfunction_attribute, CUstream};
use ::std::os::raw::{c_uint, c_void};
use std::{mem, ptr};
@@ -19,8 +19,12 @@ pub(crate) fn get_attribute(
return err;
}
let value = match cu_attrib {
- CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => hip_attrib.maxThreadsPerBlock,
- CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES => hip_attrib.sharedSizeBytes as i32,
+ CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
+ hip_attrib.maxThreadsPerBlock
+ }
+ CUfunction_attribute::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES => {
+ hip_attrib.sharedSizeBytes as i32
+ }
_ => return hipError_t::hipErrorInvalidValue,
};
unsafe { *pi = value };
diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs
index 09908bb..e0d19ae 100644
--- a/zluda/src/impl/mod.rs
+++ b/zluda/src/impl/mod.rs
@@ -1,3 +1,5 @@
+use hip_runtime_sys::hipError_t;
+
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st};
use std::{
ffi::c_void,
@@ -17,6 +19,7 @@ pub mod function;
#[cfg_attr(windows, path = "os_win.rs")]
#[cfg_attr(not(windows), path = "os_unix.rs")]
pub(crate) mod os;
+pub(crate) mod module;
#[cfg(debug_assertions)]
pub fn unimplemented() -> CUresult {
@@ -180,6 +183,13 @@ impl<T1: Encuda<To = CUresult>, T2: Encuda<To = CUresult>> Encuda for Result<T1,
}
}
+/// Lets a raw HIP status code be returned wherever an `Encuda` value is
+/// expected: the code is turned into a `CUresult` through its `Into`
+/// conversion.
+impl Encuda for hipError_t {
+    type To = CUresult;
+
+    fn encuda(self) -> Self::To {
+        self.into()
+    }
+}
+
unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T {
mem::transmute(t)
}
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs
new file mode 100644
index 0000000..463312c
--- /dev/null
+++ b/zluda/src/impl/module.rs
@@ -0,0 +1,232 @@
+use std::borrow::Cow;
+use std::collections::HashMap;
+use std::ffi::{CStr, CString};
+use std::fs::File;
+use std::io::{self, Read, Write};
+use std::ops::Add;
+use std::os::raw::c_char;
+use std::path::PathBuf;
+use std::process::Command;
+use std::{fs, mem, ptr, slice};
+
+use hip_runtime_sys::{
+ hipCtxGetCurrent, hipCtxGetDevice, hipDeviceGetAttribute, hipDeviceGetName, hipError_t,
+ hipGetDeviceProperties, hipGetStreamDeviceId, hipModuleLoadData,
+};
+use tempfile::NamedTempFile;
+
+use crate::cuda::CUmodule;
+
+/// Result of translating one PTX module, as produced by `SpirvModule::new`.
+pub struct SpirvModule {
+ /// Assembled SPIR-V words (output of `assemble()` on the translated module).
+ pub binaries: Vec<u32>,
+ /// Per-kernel metadata copied from the translated module.
+ /// NOTE(review): presumably keyed by kernel name - confirm against `ptx`.
+ pub kernel_info: HashMap<String, ptx::KernelInfo>,
+ /// Optional pair of byte blobs copied from the translated module; when
+ /// present, `compile_amd` writes the second element to a file and passes it
+ /// to clang as an extra input. The first element is not used in this file.
+ pub should_link_ptx_impl: Option<(&'static [u8], &'static [u8])>,
+ /// Build options copied from the translated module (not read in this file).
+ pub build_options: CString,
+}
+
+impl SpirvModule {
+    /// Builds a module from a NUL-terminated C string holding PTX source text.
+    ///
+    /// A non-null `text` must point to a readable, NUL-terminated buffer.
+    ///
+    /// # Errors
+    ///
+    /// * `hipErrorInvalidValue` when `text` is null.
+    /// * `hipErrorInvalidImage` when the text is not valid UTF-8 or is
+    ///   rejected by [`SpirvModule::new`].
+    pub fn new_raw(text: *const c_char) -> Result<Self, hipError_t> {
+        // `CStr::from_ptr` on a null pointer is undefined behaviour, so a null
+        // image coming from the API boundary is rejected explicitly.
+        if text.is_null() {
+            return Err(hipError_t::hipErrorInvalidValue);
+        }
+        // SAFETY: `text` is non-null; NUL termination is the caller's contract.
+        let u8_text = unsafe { CStr::from_ptr(text) };
+        let ptx_text = u8_text
+            .to_str()
+            .map_err(|_| hipError_t::hipErrorInvalidImage)?;
+        Self::new(ptx_text)
+    }
+
+    /// Parses PTX source text, translates it to a SPIR-V module and assembles
+    /// the result.
+    ///
+    /// # Errors
+    ///
+    /// Returns `hipErrorInvalidImage` when the parser fails, when the parser
+    /// recorded any recoverable errors, or when the translation to SPIR-V
+    /// fails.
+    pub fn new(ptx_text: &str) -> Result<Self, hipError_t> {
+        let mut errors = Vec::new();
+        let ast = ptx::ModuleParser::new()
+            .parse(&mut errors, ptx_text)
+            .map_err(|_| hipError_t::hipErrorInvalidImage)?;
+        // The parser can succeed while still having collected errors; treat
+        // any of them as an invalid image.
+        if !errors.is_empty() {
+            return Err(hipError_t::hipErrorInvalidImage);
+        }
+        let spirv_module =
+            ptx::to_spirv_module(ast).map_err(|_| hipError_t::hipErrorInvalidImage)?;
+        Ok(SpirvModule {
+            // `assemble()` is evaluated first, before the remaining fields are
+            // moved out of `spirv_module`.
+            binaries: spirv_module.assemble(),
+            kernel_info: spirv_module.kernel_info,
+            should_link_ptx_impl: spirv_module.should_link_ptx_impl,
+            build_options: spirv_module.build_options,
+        })
+    }
+}
+
+/// Loads a module from the file named by the NUL-terminated C string `fname`.
+///
+/// The file contents are read into memory, NUL-terminated and handed to
+/// `load_data`, which interprets them as a C string.
+///
+/// # Errors
+///
+/// * `hipErrorInvalidValue` when `fname` is null or not valid UTF-8.
+/// * `hipErrorFileNotFound` when the file cannot be opened.
+/// * `hipErrorUnknown` when reading the file fails.
+/// * Whatever `load_data` reports for the file contents.
+pub(crate) fn load(module: *mut CUmodule, fname: *const i8) -> Result<(), hipError_t> {
+    if fname.is_null() {
+        return Err(hipError_t::hipErrorInvalidValue);
+    }
+    // SAFETY: `fname` is non-null; NUL termination is the caller's contract.
+    // `CStr::from_ptr` includes the terminator, unlike a slice that stops just
+    // before it (which `CStr::from_bytes_with_nul` always rejects).
+    let file_name = unsafe { CStr::from_ptr(fname as *const c_char) };
+    let valid_file_name = file_name
+        .to_str()
+        .map_err(|_| hipError_t::hipErrorInvalidValue)?;
+    let mut file = File::open(valid_file_name).map_err(|_| hipError_t::hipErrorFileNotFound)?;
+    let mut file_buffer = Vec::new();
+    file.read_to_end(&mut file_buffer)
+        .map_err(|_| hipError_t::hipErrorUnknown)?;
+    drop(file);
+    // `load_data` scans the image as a C string, so the buffer must end with a
+    // NUL byte; without it the scan would run past the end of the allocation.
+    file_buffer.push(0);
+    load_data(module, file_buffer.as_ptr() as _)
+}
+
+/// Loads a module from an in-memory image.
+///
+/// `image` is reinterpreted as a pointer to C characters and given to
+/// `SpirvModule::new_raw`; on success the resulting module is forwarded to
+/// `load_data_impl` together with `module`. The first error encountered is
+/// returned unchanged.
+pub(crate) fn load_data(
+    module: *mut CUmodule,
+    image: *const std::ffi::c_void,
+) -> Result<(), hipError_t> {
+    SpirvModule::new_raw(image as *const c_char)
+        .and_then(|parsed| load_data_impl(module, parsed))
+}
+
+/// Compiles an already translated module for the device of the current HIP
+/// context and loads the resulting binary into `pmod`.
+///
+/// Steps: query the current context's device, read its properties, extract
+/// the `gcnArchName` string, run `compile_amd` for that architecture and pass
+/// the produced binary to `hipModuleLoadData`.
+///
+/// # Errors
+///
+/// * The HIP error returned by `hipCtxGetDevice`, `hipGetDeviceProperties` or
+///   `hipModuleLoadData` when any of them fails.
+/// * `hipErrorUnknown` when the architecture name is not NUL-terminated, is
+///   not valid UTF-8, or when `compile_amd` fails.
+pub fn load_data_impl(pmod: *mut CUmodule, spirv_data: SpirvModule) -> Result<(), hipError_t> {
+    let mut dev = 0;
+    let err = unsafe { hipCtxGetDevice(&mut dev) };
+    if err != hipError_t::hipSuccess {
+        return Err(err);
+    }
+    let mut props = unsafe { mem::zeroed() };
+    let err = unsafe { hipGetDeviceProperties(&mut props, dev) };
+    if err != hipError_t::hipSuccess {
+        return Err(err);
+    }
+    // SAFETY: the slice covers exactly the `gcnArchName` array inside `props`,
+    // viewed as bytes.
+    let gcn_arch_bytes: &[u8] = unsafe {
+        slice::from_raw_parts(
+            props.gcnArchName.as_ptr() as *const u8,
+            props.gcnArchName.len(),
+        )
+    };
+    // The name is followed by padding NULs inside the fixed-size array, which
+    // `CStr::from_bytes_with_nul` rejects as interior NULs. Cut the buffer at
+    // the first NUL instead; the scan stays inside the array.
+    let name_len = gcn_arch_bytes
+        .iter()
+        .position(|&byte| byte == 0)
+        .ok_or(hipError_t::hipErrorUnknown)?;
+    let name = std::str::from_utf8(&gcn_arch_bytes[..name_len])
+        .map_err(|_| hipError_t::hipErrorUnknown)?;
+    let arch_binary = compile_amd(
+        name,
+        &spirv_data.binaries[..],
+        spirv_data.should_link_ptx_impl,
+    )
+    .map_err(|_| hipError_t::hipErrorUnknown)?;
+    let err = unsafe { hipModuleLoadData(pmod as _, arch_binary.as_ptr() as _) };
+    if err != hipError_t::hipSuccess {
+        return Err(err);
+    }
+    Ok(())
+}
+
+/// Fallback location of the `llvm-spirv` executable; `compile_amd` uses the
+/// `LLVM_SPIRV` environment variable instead when it is set.
+const LLVM_SPIRV: &str = "/home/vosen/amd/llvm-project/build/bin/llvm-spirv";
+/// Root directory under which the `bin` tools and `amdgcn/bitcode` libraries
+/// are looked up.
+const AMDGPU: &str = "/opt/amdgpu-pro/";
+/// Value passed to clang's `-target` option.
+const AMDGPU_TARGET: &str = "amdgcn-amd-amdhsa";
+/// Device-independent bitcode files handed to `llvm-link` for every module.
+const AMDGPU_BITCODE: [&str; 8] = [
+    "opencl.bc",
+    "ocml.bc",
+    "ockl.bc",
+    "oclc_correctly_rounded_sqrt_off.bc",
+    "oclc_daz_opt_on.bc",
+    "oclc_finite_only_off.bc",
+    "oclc_unsafe_math_off.bc",
+    "oclc_wavefrontsize64_off.bc",
+];
+/// File-name prefix of the device-specific bitcode file; the device part and
+/// `.bc` are appended in `get_bitcode_paths`.
+const AMDGPU_BITCODE_DEVICE_PREFIX: &str = "oclc_isa_version_";
+
+/// Compiles SPIR-V into a binary for `device_name` by running external tools.
+///
+/// All intermediate files live in a fresh temporary directory. The pipeline:
+///
+/// 1. `llvm-spirv -r` converts the SPIR-V words to LLVM bitcode. The
+///    executable comes from the `LLVM_SPIRV` environment variable, falling
+///    back to the `LLVM_SPIRV` constant.
+/// 2. `llvm-link --only-needed` links that bitcode with the files returned by
+///    `get_bitcode_paths`.
+/// 3. `clang` compiles the linked bitcode with `-mcpu=<device_name>`. When
+///    `ptx_lib` is `Some`, its second element is written to a file and passed
+///    as an additional input.
+///
+/// The compiled binary is also copied into `/tmp/zluda`, and its bytes are
+/// returned.
+///
+/// # Errors
+///
+/// Returns an `io::Error` for any file-system failure, when a tool cannot be
+/// started, or when a tool exits unsuccessfully. Tool failures are reported
+/// as errors rather than panics because this code is reached from
+/// `extern "system"` entry points.
+fn compile_amd(
+    device_name: &str,
+    spirv_il: &[u32],
+    ptx_lib: Option<(&'static [u8], &'static [u8])>,
+) -> io::Result<Vec<u8>> {
+    use std::env;
+    let dir = tempfile::tempdir()?;
+    let mut spirv = NamedTempFile::new_in(&dir)?;
+    let llvm = NamedTempFile::new_in(&dir)?;
+    // SAFETY: reinterprets the `u32` slice as its underlying bytes; the length
+    // is scaled by the element size so the view covers the same memory.
+    let spirv_il_u8 = unsafe {
+        slice::from_raw_parts(
+            spirv_il.as_ptr() as *const u8,
+            spirv_il.len() * mem::size_of::<u32>(),
+        )
+    };
+    spirv.write_all(spirv_il_u8)?;
+
+    // Step 1: SPIR-V -> LLVM bitcode.
+    let llvm_spirv_path = match env::var("LLVM_SPIRV") {
+        Ok(path) => Cow::Owned(path),
+        Err(_) => Cow::Borrowed(LLVM_SPIRV),
+    };
+    run_tool(
+        Command::new(&*llvm_spirv_path)
+            .arg("-r")
+            .arg("-o")
+            .arg(llvm.path())
+            .arg(spirv.path()),
+        "llvm-spirv",
+    )?;
+
+    // Step 2: link against the device bitcode libraries.
+    let linked_binary = NamedTempFile::new_in(&dir)?;
+    let mut linker_cmd = Command::new(amdgpu_tool("llvm-link"));
+    linker_cmd
+        .arg("--only-needed")
+        .arg("-o")
+        .arg(linked_binary.path())
+        .arg(llvm.path())
+        .args(get_bitcode_paths(device_name));
+    if cfg!(debug_assertions) {
+        linker_cmd.arg("-v");
+    }
+    run_tool(&mut linker_cmd, "llvm-link")?;
+
+    // Step 3: compile the linked bitcode for the target device.
+    let mut ptx_lib_bitcode = NamedTempFile::new_in(&dir)?;
+    let compiled_binary = NamedTempFile::new_in(&dir)?;
+    let mut compiler_cmd = Command::new(amdgpu_tool("clang"));
+    compiler_cmd
+        .arg(format!("-mcpu={}", device_name))
+        .arg("-nogpulib")
+        .arg("-mno-wavefrontsize64")
+        .arg("-O3")
+        .arg("-Xlinker")
+        .arg("--no-undefined")
+        .arg("-target")
+        .arg(AMDGPU_TARGET)
+        .arg("-o")
+        .arg(compiled_binary.path())
+        .arg("-x")
+        .arg("ir")
+        .arg(linked_binary.path());
+    if let Some((_, bitcode)) = ptx_lib {
+        ptx_lib_bitcode.write_all(bitcode)?;
+        compiler_cmd.arg(ptx_lib_bitcode.path());
+    }
+    if cfg!(debug_assertions) {
+        compiler_cmd.arg("-v");
+    }
+    run_tool(&mut compiler_cmd, "clang")?;
+
+    let compiled_bin_path = compiled_binary.path();
+    let result = fs::read(compiled_bin_path)?;
+
+    // Keep a copy of the binary outside the temporary directory, which is
+    // deleted when `dir` goes out of scope.
+    let file_name = compiled_bin_path.file_name().ok_or_else(|| {
+        io::Error::new(
+            io::ErrorKind::Other,
+            "compiled binary path has no file name",
+        )
+    })?;
+    let mut persistent = PathBuf::from("/tmp/zluda");
+    fs::create_dir_all(&persistent)?;
+    persistent.push(file_name);
+    fs::copy(compiled_bin_path, persistent)?;
+    Ok(result)
+}
+
+/// Returns the path of the executable `name` in the `bin` directory under
+/// `AMDGPU`.
+fn amdgpu_tool(name: &str) -> PathBuf {
+    let mut path = PathBuf::from(AMDGPU);
+    path.push("bin");
+    path.push(name);
+    path
+}
+
+/// Runs `cmd` to completion and converts an unsuccessful exit status into an
+/// `io::Error` that names `tool`.
+fn run_tool(cmd: &mut Command, tool: &str) -> io::Result<()> {
+    let status = cmd.status()?;
+    if status.success() {
+        Ok(())
+    } else {
+        Err(io::Error::new(
+            io::ErrorKind::Other,
+            format!("{} failed: {}", tool, status),
+        ))
+    }
+}
+
+/// Lists the bitcode files to link for `device_name`.
+///
+/// Yields every entry of `AMDGPU_BITCODE` under `<AMDGPU>/amdgcn/bitcode`,
+/// followed by one device-specific file named
+/// `<AMDGPU_BITCODE_DEVICE_PREFIX><device part>.bc` in the same directory.
+///
+/// The device part is `device_name` up to (not including) the first `:`, with
+/// its first three bytes dropped.
+/// NOTE(review): presumably those three bytes are a `gfx` prefix - confirm.
+/// When that range does not exist (name shorter than three bytes, or not on a
+/// character boundary) the device part is empty instead of causing a panic.
+fn get_bitcode_paths(device_name: &str) -> impl Iterator<Item = PathBuf> {
+    let mut bitcode_dir = PathBuf::from(AMDGPU);
+    bitcode_dir.push("amdgcn");
+    bitcode_dir.push("bitcode");
+
+    let generic_dir = bitcode_dir.clone();
+    let generic_paths = AMDGPU_BITCODE
+        .iter()
+        .map(move |file_name| generic_dir.join(file_name));
+
+    // Everything from the first ':' onwards is ignored.
+    let base_name = device_name.split(':').next().unwrap_or(device_name);
+    // Checked slicing: `&base_name[3..]` would panic on short names.
+    let isa_version = base_name.get(3..).unwrap_or("");
+    let additional_path = bitcode_dir.join(format!(
+        "{}{}{}",
+        AMDGPU_BITCODE_DEVICE_PREFIX, isa_version, ".bc"
+    ));
+    generic_paths.chain(std::iter::once(additional_path))
+}