diff options
Diffstat (limited to 'zluda/src/impl/function.rs')
-rw-r--r-- | zluda/src/impl/function.rs | 81 |
1 files changed, 46 insertions, 35 deletions
diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs index 05f864b..548936f 100644 --- a/zluda/src/impl/function.rs +++ b/zluda/src/impl/function.rs @@ -51,6 +51,37 @@ impl LegacyArguments { } } +unsafe fn set_arg( + kernel: &ocl_core::Kernel, + arg_index: usize, + arg_size: usize, + arg_value: *const c_void, + is_mem: bool, +) -> Result<(), CUresult> { + if is_mem { + let error = 0; + unsafe { + ocl_core::ffi::clSetKernelArgSVMPointer( + kernel.as_ptr(), + arg_index as u32, + *(arg_value as *const _), + ) + }; + if error != 0 { + panic!("clSetKernelArgSVMPointer"); + } + } else { + unsafe { + ocl_core::set_kernel_arg( + kernel, + arg_index as u32, + ocl_core::ArgVal::from_raw(arg_size, arg_value, is_mem), + )?; + }; + } + Ok(()) +} + pub fn launch_kernel( f: *mut Function, grid_dim_x: c_uint, @@ -74,27 +105,7 @@ pub fn launch_kernel( let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?; if kernel_params != ptr::null_mut() { for (i, &(arg_size, is_mem)) in func.arg_size.iter().enumerate() { - if is_mem { - let error = 0; - unsafe { - ocl_core::ffi::clSetKernelArgSVMPointer( - func.base.as_ptr(), - i as u32, - *(*kernel_params.add(i) as *const _), - ) - }; - if error != 0 { - panic!("clSetKernelArgSVMPointer"); - } - } else { - unsafe { - ocl_core::set_kernel_arg( - &func.base, - i as u32, - ocl_core::ArgVal::from_raw(arg_size, *kernel_params.add(i), is_mem), - )?; - }; - } + unsafe { set_arg(&func.base, i, arg_size, *kernel_params.add(i), is_mem)? }; } } else { let mut offset = 0; @@ -126,15 +137,13 @@ pub fn launch_kernel( for (i, &(arg_size, is_mem)) in func.arg_size.iter().enumerate() { let buffer_offset = round_up_to_multiple(offset, arg_size); unsafe { - ocl_core::set_kernel_arg( + set_arg( &func.base, - i as u32, - ocl_core::ArgVal::from_raw( - arg_size, - buffer_ptr.add(buffer_offset) as *const _, - is_mem, - ), - )?; + i, + arg_size, + buffer_ptr.add(buffer_offset) as *const _, + is_mem, + )? }; offset = buffer_offset + arg_size; } @@ -144,11 +153,13 @@ pub fn launch_kernel( } if func.use_shared_mem { unsafe { - ocl_core::set_kernel_arg( + set_arg( &func.base, - func.arg_size.len() as u32, - ocl_core::ArgVal::from_raw(shared_mem_bytes as usize, ptr::null(), false), - )?; + func.arg_size.len(), + shared_mem_bytes as usize, + ptr::null(), + false, + )? }; } let global_dims = [ @@ -192,9 +203,9 @@ pub(crate) fn get_attribute( CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => { let max_threads = GlobalState::lock_function(func, |func| { if let ocl_core::KernelWorkGroupInfoResult::WorkGroupSize(size) = - ocl_core::get_kernel_work_group_info::<ocl_core::DeviceId>( + ocl_core::get_kernel_work_group_info::<()>( &func.base, - unsafe { ocl_core::DeviceId::null() }, + (), ocl_core::KernelWorkGroupInfo::WorkGroupSize, )? { |