summary refs log tree commit diff homepage
path: root/notcuda
diff options
context:
space:
mode:
Diffstat (limited to 'notcuda')
-rw-r--r-- notcuda/src/cuda.rs 2
-rw-r--r-- notcuda/src/impl/device.rs 5
-rw-r--r-- notcuda/src/impl/export_table.rs 15
-rw-r--r-- notcuda/src/impl/module.rs 8
4 files changed, 24 insertions, 6 deletions
diff --git a/notcuda/src/cuda.rs b/notcuda/src/cuda.rs
index 335da4a..a528981 100644
--- a/notcuda/src/cuda.rs
+++ b/notcuda/src/cuda.rs
@@ -2281,7 +2281,7 @@ pub extern "C" fn cuDevicePrimaryCtxRelease(dev: CUdevice) -> CUresult {
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuDevicePrimaryCtxRelease_v2(dev: CUdevice) -> CUresult {
- r#impl::unimplemented()
+ r#impl::device::primary_ctx_release_v2(dev.decuda())
}
#[cfg_attr(not(test), no_mangle)]
diff --git a/notcuda/src/impl/device.rs b/notcuda/src/impl/device.rs
index b8d263d..5a399dc 100644
--- a/notcuda/src/impl/device.rs
+++ b/notcuda/src/impl/device.rs
@@ -345,6 +345,11 @@ pub fn primary_ctx_retain(
Ok(())
}
+// TODO: allow for retain/reset/release of primary context
+pub(crate) fn primary_ctx_release_v2(_dev_idx: Index) -> CUresult {
+ CUresult::CUDA_SUCCESS
+}
+
#[cfg(test)]
mod test {
use super::super::test::CudaDriverFns;
diff --git a/notcuda/src/impl/export_table.rs b/notcuda/src/impl/export_table.rs
index ae9f6e3..87d7f40 100644
--- a/notcuda/src/impl/export_table.rs
+++ b/notcuda/src/impl/export_table.rs
@@ -4,7 +4,7 @@ use crate::{
cuda_impl,
};
-use super::{context, context::ContextData, module, Decuda, Encuda, GlobalState};
+use super::{context, context::ContextData, device, module, Decuda, Encuda, GlobalState};
use std::mem;
use std::os::raw::{c_uint, c_ulong, c_ushort};
use std::{
@@ -110,8 +110,17 @@ static CUDART_INTERFACE_VTABLE: [VTableEntry; CUDART_INTERFACE_LENGTH] = [
VTableEntry { ptr: ptr::null() },
];
-unsafe extern "C" fn cudart_interface_fn1(_pctx: *mut CUcontext, _dev: CUdevice) -> CUresult {
- super::unimplemented()
+unsafe extern "C" fn cudart_interface_fn1(pctx: *mut CUcontext, dev: CUdevice) -> CUresult {
+ cudart_interface_fn1_impl(pctx.decuda(), dev.decuda()).encuda()
+}
+
+fn cudart_interface_fn1_impl(
+ pctx: *mut *mut context::Context,
+ dev: device::Index,
+) -> Result<(), CUresult> {
+ let ctx_ptr = GlobalState::lock_device(dev, |d| &mut d.primary_context as *mut _)?;
+ unsafe { *pctx = ctx_ptr };
+ Ok(())
}
/*
diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs
index 4422107..e19d8de 100644
--- a/notcuda/src/impl/module.rs
+++ b/notcuda/src/impl/module.rs
@@ -110,7 +110,6 @@ pub fn get_function(
entry.insert(new_module)
}
};
- //let compiled_module = unsafe { transmute_lifetime_mut(compiled_module) };
let kernel = match compiled_module.kernels.entry(name) {
hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(),
hash_map::Entry::Vacant(entry) => {
@@ -121,8 +120,13 @@ pub fn get_function(
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
})
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
- let kernel =
+ let mut kernel =
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
+ kernel.set_indirect_access(
+ l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
+ | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
+ | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED
+ )?;
entry.insert(Box::new(Function::new(FunctionData {
base: kernel,
arg_size: kernel_info.arguments_sizes.clone(),