aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src/impl/function.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda/src/impl/function.rs')
-rw-r--r--zluda/src/impl/function.rs81
1 files changed, 46 insertions, 35 deletions
diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs
index 05f864b..548936f 100644
--- a/zluda/src/impl/function.rs
+++ b/zluda/src/impl/function.rs
@@ -51,6 +51,37 @@ impl LegacyArguments {
}
}
+unsafe fn set_arg(
+ kernel: &ocl_core::Kernel,
+ arg_index: usize,
+ arg_size: usize,
+ arg_value: *const c_void,
+ is_mem: bool,
+) -> Result<(), CUresult> {
+ if is_mem {
+ let error = 0;
+ unsafe {
+ ocl_core::ffi::clSetKernelArgSVMPointer(
+ kernel.as_ptr(),
+ arg_index as u32,
+ *(arg_value as *const _),
+ )
+ };
+ if error != 0 {
+ panic!("clSetKernelArgSVMPointer");
+ }
+ } else {
+ unsafe {
+ ocl_core::set_kernel_arg(
+ kernel,
+ arg_index as u32,
+ ocl_core::ArgVal::from_raw(arg_size, arg_value, is_mem),
+ )?;
+ };
+ }
+ Ok(())
+}
+
pub fn launch_kernel(
f: *mut Function,
grid_dim_x: c_uint,
@@ -74,27 +105,7 @@ pub fn launch_kernel(
let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?;
if kernel_params != ptr::null_mut() {
for (i, &(arg_size, is_mem)) in func.arg_size.iter().enumerate() {
- if is_mem {
- let error = 0;
- unsafe {
- ocl_core::ffi::clSetKernelArgSVMPointer(
- func.base.as_ptr(),
- i as u32,
- *(*kernel_params.add(i) as *const _),
- )
- };
- if error != 0 {
- panic!("clSetKernelArgSVMPointer");
- }
- } else {
- unsafe {
- ocl_core::set_kernel_arg(
- &func.base,
- i as u32,
- ocl_core::ArgVal::from_raw(arg_size, *kernel_params.add(i), is_mem),
- )?;
- };
- }
+ unsafe { set_arg(&func.base, i, arg_size, *kernel_params.add(i), is_mem)? };
}
} else {
let mut offset = 0;
@@ -126,15 +137,13 @@ pub fn launch_kernel(
for (i, &(arg_size, is_mem)) in func.arg_size.iter().enumerate() {
let buffer_offset = round_up_to_multiple(offset, arg_size);
unsafe {
- ocl_core::set_kernel_arg(
+ set_arg(
&func.base,
- i as u32,
- ocl_core::ArgVal::from_raw(
- arg_size,
- buffer_ptr.add(buffer_offset) as *const _,
- is_mem,
- ),
- )?;
+ i,
+ arg_size,
+ buffer_ptr.add(buffer_offset) as *const _,
+ is_mem,
+ )?
};
offset = buffer_offset + arg_size;
}
@@ -144,11 +153,13 @@ pub fn launch_kernel(
}
if func.use_shared_mem {
unsafe {
- ocl_core::set_kernel_arg(
+ set_arg(
&func.base,
- func.arg_size.len() as u32,
- ocl_core::ArgVal::from_raw(shared_mem_bytes as usize, ptr::null(), false),
- )?;
+ func.arg_size.len(),
+ shared_mem_bytes as usize,
+ ptr::null(),
+ false,
+ )?
};
}
let global_dims = [
@@ -192,9 +203,9 @@ pub(crate) fn get_attribute(
CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
let max_threads = GlobalState::lock_function(func, |func| {
if let ocl_core::KernelWorkGroupInfoResult::WorkGroupSize(size) =
- ocl_core::get_kernel_work_group_info::<ocl_core::DeviceId>(
+ ocl_core::get_kernel_work_group_info::<()>(
&func.base,
- unsafe { ocl_core::DeviceId::null() },
+ (),
ocl_core::KernelWorkGroupInfo::WorkGroupSize,
)?
{