diff options
Diffstat (limited to 'zluda/src/impl/function.rs')
-rw-r--r-- | zluda/src/impl/function.rs | 89 |
1 files changed, 52 insertions, 37 deletions
diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs index 2a35512..2658d27 100644 --- a/zluda/src/impl/function.rs +++ b/zluda/src/impl/function.rs @@ -1,3 +1,5 @@ +use ocl_core::DeviceId; + use super::{stream::Stream, CUresult, GlobalState, HasLivenessCookie, LiveCheck}; use crate::cuda::CUfunction_attribute; use ::std::os::raw::{c_uint, c_void}; @@ -24,10 +26,9 @@ impl HasLivenessCookie for FunctionData { } pub struct FunctionData { - pub base: l0::Kernel<'static>, + pub base: ocl_core::Kernel, pub arg_size: Vec<usize>, pub use_shared_mem: bool, - pub properties: Option<Box<l0::sys::ze_kernel_properties_t>>, pub legacy_args: LegacyArguments, } @@ -50,18 +51,6 @@ impl LegacyArguments { } } -impl FunctionData { - fn get_properties(&mut self) -> Result<&l0::sys::ze_kernel_properties_t, l0::sys::ze_result_t> { - if let None = self.properties { - self.properties = Some(self.base.get_properties()?) - } - match self.properties { - Some(ref props) => Ok(props.as_ref()), - None => unsafe { hint::unreachable_unchecked() }, - } - } -} - pub fn launch_kernel( f: *mut Function, grid_dim_x: c_uint, @@ -81,13 +70,16 @@ pub fn launch_kernel( { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - GlobalState::lock_enqueue(hstream, |cmd_list, signal, wait| { + GlobalState::lock_enqueue(hstream, |queue| { let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?; if kernel_params != ptr::null_mut() { for (i, arg_size) in func.arg_size.iter().enumerate() { unsafe { - func.base - .set_arg_raw(i as u32, *arg_size, *kernel_params.add(i))? + ocl_core::set_kernel_arg( + &func.base, + i as u32, + ocl_core::ArgVal::from_raw(*arg_size, *kernel_params.add(i), false), + )?; }; } } else { @@ -120,11 +112,15 @@ pub fn launch_kernel( for (i, arg_size) in func.arg_size.iter().enumerate() { let buffer_offset = round_up_to_multiple(offset, *arg_size); unsafe { - func.base.set_arg_raw( + ocl_core::set_kernel_arg( + &func.base, i as u32, - *arg_size, - buffer_ptr.add(buffer_offset) as *const _, - )? + ocl_core::ArgVal::from_raw( + *arg_size, + buffer_ptr.add(buffer_offset) as *const _, + false, + ), + )?; }; offset = buffer_offset + *arg_size; } @@ -134,24 +130,34 @@ pub fn launch_kernel( } if func.use_shared_mem { unsafe { - func.base.set_arg_raw( + ocl_core::set_kernel_arg( + &func.base, func.arg_size.len() as u32, - shared_mem_bytes as usize, - ptr::null(), - )? + ocl_core::ArgVal::from_raw(shared_mem_bytes as usize, ptr::null(), false), + )?; }; } - func.base - .set_group_size(block_dim_x, block_dim_y, block_dim_z)?; - func.legacy_args.reset(); + let global_dims = [ + (block_dim_x * grid_dim_x) as usize, + (block_dim_y * grid_dim_y) as usize, + (block_dim_z * grid_dim_z) as usize, + ]; unsafe { - cmd_list.append_launch_kernel( - &mut func.base, - &[grid_dim_x, grid_dim_y, grid_dim_z], - Some(signal), - wait, - )?; - } + ocl_core::enqueue_kernel::<&mut ocl_core::Event, ocl_core::Event>( + queue, + &func.base, + 3, + None, + &global_dims, + Some([ + block_dim_x as usize, + block_dim_y as usize, + block_dim_z as usize, + ]), + None, + None, + )? + }; Ok::<_, CUresult>(()) }) } @@ -171,8 +177,17 @@ pub(crate) fn get_attribute( match attrib { CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => { let max_threads = GlobalState::lock_function(func, |func| { - let props = func.get_properties()?; - Ok::<_, CUresult>(props.maxSubgroupSize * props.maxNumSubgroups) + if let ocl_core::KernelWorkGroupInfoResult::WorkGroupSize(size) = + ocl_core::get_kernel_work_group_info::<ocl_core::DeviceId>( + &func.base, + unsafe { ocl_core::DeviceId::null() }, + ocl_core::KernelWorkGroupInfo::WorkGroupSize, + )? + { + Ok(size) + } else { + Err(CUresult::CUDA_ERROR_UNKNOWN) + } })??; unsafe { *pi = max_threads as i32 }; Ok(()) |