aboutsummaryrefslogtreecommitdiffhomepage
path: root/notcuda/src/impl/module.rs
diff options
context:
space:
mode:
Diffstat (limited to 'notcuda/src/impl/module.rs')
-rw-r--r--notcuda/src/impl/module.rs8
1 files changed, 6 insertions, 2 deletions
diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs
index 4422107..e19d8de 100644
--- a/notcuda/src/impl/module.rs
+++ b/notcuda/src/impl/module.rs
@@ -110,7 +110,6 @@ pub fn get_function(
entry.insert(new_module)
}
};
- //let compiled_module = unsafe { transmute_lifetime_mut(compiled_module) };
let kernel = match compiled_module.kernels.entry(name) {
hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(),
hash_map::Entry::Vacant(entry) => {
@@ -121,8 +120,13 @@ pub fn get_function(
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
})
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
- let kernel =
+ let mut kernel =
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
+ kernel.set_indirect_access(
+ l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
+ | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
+ | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED
+ )?;
entry.insert(Box::new(Function::new(FunctionData {
base: kernel,
arg_size: kernel_info.arguments_sizes.clone(),