aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-09-24 01:54:16 +0200
committerAndrzej Janik <[email protected]>2020-09-24 01:54:16 +0200
commit3f41f21acb51f7a1d305630dc2a4e5c5df5e4a83 (patch)
treed6d2fb03034b00b8b711bed08802bf0e2dbb275e
parent03005140dde6c45a47d1fe03a183d76af38b7a12 (diff)
downloadZLUDA-3f41f21acb51f7a1d305630dc2a4e5c5df5e4a83.tar.gz
ZLUDA-3f41f21acb51f7a1d305630dc2a4e5c5df5e4a83.zip
Implement more host code, moving execution further
-rw-r--r--level_zero/src/ze.rs6
-rw-r--r--notcuda/src/cuda.rs2
-rw-r--r--notcuda/src/impl/device.rs48
-rw-r--r--notcuda/src/impl/export_table.rs25
-rw-r--r--notcuda/src/impl/mod.rs15
-rw-r--r--notcuda/src/impl/module.rs75
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/test/spirv_run/pred_not.ptx28
-rw-r--r--ptx/src/test/spirv_run/pred_not.spvtxt78
9 files changed, 241 insertions, 37 deletions
diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs
index cee736c..16b9130 100644
--- a/level_zero/src/ze.rs
+++ b/level_zero/src/ze.rs
@@ -118,6 +118,12 @@ impl Device {
Ok(props)
}
+ pub fn get_compute_properties(&self) -> Result<Box<sys::ze_device_compute_properties_t>> {
+ let mut props = Box::new(unsafe { mem::zeroed::<sys::ze_device_compute_properties_t>() });
+ check! { sys::zeDeviceGetComputeProperties(self.0, props.as_mut()) };
+ Ok(props)
+ }
+
pub unsafe fn mem_alloc_device(
&mut self,
ctx: &mut Context,
diff --git a/notcuda/src/cuda.rs b/notcuda/src/cuda.rs
index 3267042..122f0da 100644
--- a/notcuda/src/cuda.rs
+++ b/notcuda/src/cuda.rs
@@ -2501,7 +2501,7 @@ pub extern "C" fn cuModuleGetFunction(
hmod: CUmodule,
name: *const ::std::os::raw::c_char,
) -> CUresult {
- r#impl::unimplemented()
+ r#impl::module::get_function(hfunc.decuda(), hmod.decuda(), name).encuda()
}
#[no_mangle]
diff --git a/notcuda/src/impl/device.rs b/notcuda/src/impl/device.rs
index 8a8f2f8..db39efd 100644
--- a/notcuda/src/impl/device.rs
+++ b/notcuda/src/impl/device.rs
@@ -1,4 +1,4 @@
-use super::{context, CUresult, Error};
+use super::{context, transmute_lifetime, CUresult, Error};
use crate::cuda;
use cuda::{CUdevice_attribute, CUuuid_st};
use std::{
@@ -25,6 +25,7 @@ pub struct Device {
properties: Option<Box<l0::sys::ze_device_properties_t>>,
image_properties: Option<Box<l0::sys::ze_device_image_properties_t>>,
memory_properties: Option<Vec<l0::sys::ze_device_memory_properties_t>>,
+ compute_properties: Option<Box<l0::sys::ze_device_compute_properties_t>>,
}
unsafe impl Send for Device {}
@@ -48,6 +49,7 @@ impl Device {
properties: None,
image_properties: None,
memory_properties: None,
+ compute_properties: None,
})
}
@@ -80,6 +82,16 @@ impl Device {
Err(e) => Err(e),
}
}
+
+ fn get_compute_properties(&mut self) -> l0::Result<&l0::sys::ze_device_compute_properties_t> {
+ if let Some(ref prop) = self.compute_properties {
+ return Ok(prop);
+ }
+ match self.base.get_compute_properties() {
+ Ok(prop) => Ok(self.compute_properties.get_or_insert(prop)),
+ Err(e) => Err(e),
+ }
+ }
}
pub fn init(driver: &l0::Driver) -> l0::Result<()> {
@@ -166,10 +178,6 @@ pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult>
Ok(())
}
-unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T {
- mem::transmute(t)
-}
-
pub fn total_mem_v2(bytes: *mut usize, dev: Index) -> Result<(), CUresult> {
if bytes == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
@@ -232,6 +240,34 @@ pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Re
.maxImageDims1D,
c_int::max_value() as u32,
) as c_int,
+ CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X => {
+ let props = dev.get_compute_properties().map_err(Error::L0)?;
+ cmp::max(i32::max_value() as u32, props.maxGroupCountX) as i32
+ }
+ CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y => {
+ let props = dev.get_compute_properties().map_err(Error::L0)?;
+ cmp::max(i32::max_value() as u32, props.maxGroupCountY) as i32
+ }
+ CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z => {
+ let props = dev.get_compute_properties().map_err(Error::L0)?;
+ cmp::max(i32::max_value() as u32, props.maxGroupCountZ) as i32
+ }
+ CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X => {
+ let props = dev.get_compute_properties().map_err(Error::L0)?;
+ cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32
+ }
+ CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y => {
+ let props = dev.get_compute_properties().map_err(Error::L0)?;
+ cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32
+ }
+ CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z => {
+ let props = dev.get_compute_properties().map_err(Error::L0)?;
+ cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32
+ }
+ CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
+ let props = dev.get_compute_properties().map_err(Error::L0)?;
+ cmp::max(i32::max_value() as u32, props.maxTotalGroupSize) as i32
+ }
_ => {
// TODO: support more attributes for CUDA runtime
/*
@@ -311,8 +347,6 @@ pub fn primary_ctx_retain(pctx: *mut *mut context::Context, dev: Index) -> Resul
mod tests {
use super::super::test::CudaDriverFns;
use super::super::CUresult;
- use crate::cuda::CUuuid;
- use std::{ffi::c_void, mem, ptr};
cuda_driver_test!(primary_ctx_default_inactive);
diff --git a/notcuda/src/impl/export_table.rs b/notcuda/src/impl/export_table.rs
index 233c496..9a6d72c 100644
--- a/notcuda/src/impl/export_table.rs
+++ b/notcuda/src/impl/export_table.rs
@@ -66,12 +66,11 @@ static TOOLS_RUNTIME_CALLBACK_HOOKS_VTABLE: [VTableEntry; TOOLS_RUNTIME_CALLBACK
ptr: runtime_callback_hooks_fn5 as *const (),
},
];
-static mut TOOLS_RUNTIME_CALLBACK_HOOKS_FN1_SPACE: [u8; 512] = [0; 512];
+static mut TOOLS_RUNTIME_CALLBACK_HOOKS_FN1_SPACE: [usize; 512] = [0; 512];
-unsafe extern "C" fn runtime_callback_hooks_fn1(ptr: *mut *mut u8, size: *mut usize) -> *mut u8 {
+unsafe extern "C" fn runtime_callback_hooks_fn1(ptr: *mut *mut usize, size: *mut usize) {
*ptr = TOOLS_RUNTIME_CALLBACK_HOOKS_FN1_SPACE.as_mut_ptr();
*size = TOOLS_RUNTIME_CALLBACK_HOOKS_FN1_SPACE.len();
- return TOOLS_RUNTIME_CALLBACK_HOOKS_FN1_SPACE.as_mut_ptr();
}
static mut TOOLS_RUNTIME_CALLBACK_HOOKS_FN5_SPACE: [u8; 2] = [0; 2];
@@ -198,9 +197,14 @@ struct FatbinFileHeader {
unsafe extern "C" fn get_module_from_cubin(
result: *mut CUmodule,
fatbinc_wrapper: *const FatbincWrapper,
- _: *mut c_void,
- _: *mut c_void,
+ ptr1: *mut c_void,
+ ptr2: *mut c_void,
) -> CUresult {
+ // Not sure what those twoparameters are actually used for,
+ // they are somehow involved in __cudaRegisterHostVar
+ if ptr1 != ptr::null_mut() || ptr2 != ptr::null_mut() {
+ return CUresult::CUDA_ERROR_NOT_SUPPORTED;
+ }
if result == ptr::null_mut()
|| (*fatbinc_wrapper).magic != FATBINC_MAGIC
|| (*fatbinc_wrapper).version != FATBINC_VERSION
@@ -208,11 +212,6 @@ unsafe extern "C" fn get_module_from_cubin(
return CUresult::CUDA_ERROR_INVALID_VALUE;
}
let result = result.decuda();
- let mut dev_count = 0;
- let cu_result = device::get_count(&mut dev_count);
- if cu_result != CUresult::CUDA_SUCCESS {
- return cu_result;
- }
let fatbin_header = (*fatbinc_wrapper).data;
if (*fatbin_header).magic != FATBIN_MAGIC || (*fatbin_header).version != FATBIN_VERSION {
return CUresult::CUDA_ERROR_INVALID_VALUE;
@@ -235,7 +234,7 @@ unsafe extern "C" fn get_module_from_cubin(
},
Err(_) => continue,
};
- let module = module::Module::compile(kernel_text_string, dev_count as usize);
+ let module = module::ModuleData::compile_spirv(kernel_text_string);
match module {
Ok(module) => {
*result = Box::into_raw(Box::new(module));
@@ -310,7 +309,7 @@ unsafe extern "C" fn context_local_storage_ctor(
}
fn context_local_storage_ctor_impl(
- cu_ctx: *mut context::Context,
+ mut cu_ctx: *mut context::Context,
mgr: *mut cuda_impl::rt::ContextStateManager,
ctx_state: *mut cuda_impl::rt::ContextState,
dtor_cb: Option<
@@ -322,7 +321,7 @@ fn context_local_storage_ctor_impl(
>,
) -> Result<(), CUresult> {
if cu_ctx == ptr::null_mut() {
- return Err(CUresult::CUDA_ERROR_NOT_SUPPORTED);
+ context::get_current(&mut cu_ctx)?;
}
unsafe { &*cu_ctx }
.as_ref()
diff --git a/notcuda/src/impl/mod.rs b/notcuda/src/impl/mod.rs
index 7813532..c37b85d 100644
--- a/notcuda/src/impl/mod.rs
+++ b/notcuda/src/impl/mod.rs
@@ -1,5 +1,5 @@
-use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUresult, CUmodule};
-use std::{ffi::c_void, mem::ManuallyDrop, os::raw::c_int, sync::Mutex};
+use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunction, CUmod_st, CUmodule, CUresult};
+use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex};
#[cfg(test)]
#[macro_use]
@@ -206,6 +206,10 @@ pub fn init() -> l0::Result<()> {
Ok(())
}
+unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T {
+ mem::transmute(t)
+}
+
pub fn driver_get_version() -> c_int {
i32::max_value()
}
@@ -234,7 +238,10 @@ impl Decuda<*mut c_void> for CUdeviceptr {
}
}
-impl<'a> CudaRepr for CUmodule {
- type Impl = *mut module::Module;
+impl<'a> CudaRepr for CUmod_st {
+ type Impl = module::Module;
}
+impl<'a> CudaRepr for CUfunction {
+ type Impl = *mut module::Function;
+}
diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs
index 4b664b5..06d050d 100644
--- a/notcuda/src/impl/module.rs
+++ b/notcuda/src/impl/module.rs
@@ -1,8 +1,14 @@
+use std::{ffi::c_void, ffi::CStr, mem, os::raw::c_char, ptr, slice, sync::Mutex};
+
+use super::{transmute_lifetime, CUresult};
use ptx;
-pub struct Module {
- spirv_code: Vec<u32>,
- compiled_code: Vec<Option<Vec<u8>>>, // size as big as the number of devices
+use super::context;
+
+pub type Module = Mutex<ModuleData>;
+
+pub struct ModuleData {
+ base: l0::Module,
}
pub enum ModuleCompileError<'a> {
@@ -10,21 +16,35 @@ pub enum ModuleCompileError<'a> {
Vec<ptx::ast::PtxError>,
Option<ptx::ParseError<usize, ptx::Token<'a>, ptx::ast::PtxError>>,
),
- Compile(ptx::SpirvError),
+ Compile(ptx::TranslateError),
+ L0(l0::sys::ze_result_t),
+ CUDA(CUresult),
}
impl<'a> ModuleCompileError<'a> {
pub fn get_build_log(&self) {}
}
-impl<'a> From<ptx::SpirvError> for ModuleCompileError<'a> {
- fn from(err: ptx::SpirvError) -> Self {
+impl<'a> From<ptx::TranslateError> for ModuleCompileError<'a> {
+ fn from(err: ptx::TranslateError) -> Self {
ModuleCompileError::Compile(err)
}
}
-impl Module {
- pub fn compile(ptx_text: &str, devices: usize) -> Result<Self, ModuleCompileError> {
+impl<'a> From<l0::sys::ze_result_t> for ModuleCompileError<'a> {
+ fn from(err: l0::sys::ze_result_t) -> Self {
+ ModuleCompileError::L0(err)
+ }
+}
+
+impl<'a> From<CUresult> for ModuleCompileError<'a> {
+ fn from(err: CUresult) -> Self {
+ ModuleCompileError::CUDA(err)
+ }
+}
+
+impl ModuleData {
+ pub fn compile_spirv<'a>(ptx_text: &'a str) -> Result<Module, ModuleCompileError<'a>> {
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text);
let ast = match ast {
@@ -33,9 +53,40 @@ impl Module {
Ok(ast) => ast,
};
let spirv = ptx::to_spirv(ast)?;
- Ok(Self {
- spirv_code: spirv,
- compiled_code: vec![None; devices],
- })
+ let byte_il = unsafe {
+ slice::from_raw_parts::<u8>(
+ spirv.as_ptr() as *const _,
+ spirv.len() * mem::size_of::<u32>(),
+ )
+ };
+ let module = super::device::with_current_exclusive(|dev| {
+ l0::Module::new_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
+ });
+ match module {
+ Ok(Ok(module)) => Ok(Mutex::new(Self { base: module })),
+ Ok(Err(err)) => Err(ModuleCompileError::from(err)),
+ Err(err) => Err(ModuleCompileError::from(err)),
+ }
+ }
+}
+
+pub struct Function {
+ base: l0::Kernel<'static>,
+}
+
+pub fn get_function(
+ hfunc: *mut *mut Function,
+ hmod: *mut Module,
+ name: *const c_char,
+) -> Result<(), CUresult> {
+ if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
+ let name = unsafe { CStr::from_ptr(name) };
+ let kernel = unsafe { &*hmod }
+ .try_lock()
+ .map(|module| l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name))
+ .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)??;
+ unsafe { *hfunc = Box::into_raw(Box::new(Function { base: kernel })) };
+ Ok(())
}
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 06843f0..78c3375 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -64,6 +64,7 @@ test_ptx!(reg_local, [12u64], [13u64]);
test_ptx!(mov_address, [0xDEADu64], [0u64]);
test_ptx!(b64tof64, [111u64], [111u64]);
test_ptx!(implicit_param, [34u32], [34u32]);
+test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/pred_not.ptx b/ptx/src/test/spirv_run/pred_not.ptx
new file mode 100644
index 0000000..e058470
--- /dev/null
+++ b/ptx/src/test/spirv_run/pred_not.ptx
@@ -0,0 +1,28 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry pred_not(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp;
+ .reg .u64 temp2;
+ .reg .u64 temp3;
+ .reg .pred pred;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp, [in_addr];
+ ld.u64 temp2, [in_addr + 8];
+ setp.lt.u64 pred, temp, temp2;
+ not.pred pred, pred;
+ @pred mov.u64 temp3, 1;
+ @!pred mov.u64 temp3, 2;
+ st.u64 [out_addr], temp3;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/pred_not.spvtxt b/ptx/src/test/spirv_run/pred_not.spvtxt
new file mode 100644
index 0000000..410b1e4
--- /dev/null
+++ b/ptx/src/test/spirv_run/pred_not.spvtxt
@@ -0,0 +1,78 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ OpCapability Float64
+ %44 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "pred_not"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %47 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %bool = OpTypeBool
+%_ptr_Function_bool = OpTypePointer Function %bool
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_8 = OpConstant %ulong 8
+ %true = OpConstantTrue %bool
+ %false = OpConstantFalse %bool
+ %ulong_1 = OpConstant %ulong 1
+ %ulong_2 = OpConstant %ulong 2
+ %1 = OpFunction %void None %47
+ %14 = OpFunctionParameter %ulong
+ %15 = OpFunctionParameter %ulong
+ %42 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ %8 = OpVariable %_ptr_Function_ulong Function
+ %9 = OpVariable %_ptr_Function_bool Function
+ OpStore %2 %14
+ OpStore %3 %15
+ %17 = OpLoad %ulong %2
+ %16 = OpCopyObject %ulong %17
+ OpStore %4 %16
+ %19 = OpLoad %ulong %3
+ %18 = OpCopyObject %ulong %19
+ OpStore %5 %18
+ %21 = OpLoad %ulong %4
+ %39 = OpConvertUToPtr %_ptr_Generic_ulong %21
+ %20 = OpLoad %ulong %39
+ OpStore %6 %20
+ %23 = OpLoad %ulong %4
+ %36 = OpIAdd %ulong %23 %ulong_8
+ %40 = OpConvertUToPtr %_ptr_Generic_ulong %36
+ %22 = OpLoad %ulong %40
+ OpStore %7 %22
+ %25 = OpLoad %ulong %6
+ %26 = OpLoad %ulong %7
+ %24 = OpULessThan %bool %25 %26
+ OpStore %9 %24
+ %28 = OpLoad %bool %9
+ %27 = OpSelect %bool %28 %false %true
+ OpStore %9 %27
+ %29 = OpLoad %bool %9
+ OpBranchConditional %29 %10 %11
+ %10 = OpLabel
+ %30 = OpCopyObject %ulong %ulong_1
+ OpStore %8 %30
+ OpBranch %11
+ %11 = OpLabel
+ %31 = OpLoad %bool %9
+ OpBranchConditional %31 %13 %12
+ %12 = OpLabel
+ %32 = OpCopyObject %ulong %ulong_2
+ OpStore %8 %32
+ OpBranch %13
+ %13 = OpLabel
+ %33 = OpLoad %ulong %5
+ %34 = OpLoad %ulong %8
+ %41 = OpConvertUToPtr %_ptr_Generic_ulong %33
+ OpStore %41 %34
+ OpReturn
+ OpFunctionEnd